From 69ca4df6726310d986bf1af9863f6b6017d6585c Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 30 Mar 2021 16:09:28 -0400 Subject: [PATCH 01/36] test --- .pre-commit-config.yaml | 2 +- docs/aiplatform_v1beta1/metadata_service.rst | 11 + docs/aiplatform_v1beta1/services.rst | 1 + docs/conf.py | 6 +- .../v1/schema/predict/instance/__init__.py | 54 +- .../v1/schema/predict/instance_v1/__init__.py | 18 +- .../predict/instance_v1/types/__init__.py | 54 +- .../instance_v1/types/image_classification.py | 6 +- .../types/image_object_detection.py | 6 +- .../instance_v1/types/image_segmentation.py | 6 +- .../instance_v1/types/text_classification.py | 6 +- .../instance_v1/types/text_extraction.py | 6 +- .../instance_v1/types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 6 +- .../instance_v1/types/video_classification.py | 6 +- .../types/video_object_tracking.py | 6 +- .../v1/schema/predict/params/__init__.py | 36 +- .../v1/schema/predict/params_v1/__init__.py | 12 +- .../predict/params_v1/types/__init__.py | 36 +- .../params_v1/types/image_classification.py | 6 +- .../params_v1/types/image_object_detection.py | 6 +- .../params_v1/types/image_segmentation.py | 6 +- .../types/video_action_recognition.py | 6 +- .../params_v1/types/video_classification.py | 6 +- .../params_v1/types/video_object_tracking.py | 6 +- .../v1/schema/predict/prediction/__init__.py | 60 +- .../schema/predict/prediction_v1/__init__.py | 20 +- .../predict/prediction_v1/types/__init__.py | 60 +- .../prediction_v1/types/classification.py | 6 +- .../types/image_object_detection.py | 10 +- .../prediction_v1/types/image_segmentation.py | 6 +- .../types/tabular_classification.py | 6 +- .../prediction_v1/types/tabular_regression.py | 6 +- .../prediction_v1/types/text_extraction.py | 6 +- .../prediction_v1/types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 18 +- .../types/video_classification.py | 18 +- .../types/video_object_tracking.py | 43 +- .../schema/trainingjob/definition/__init__.py | 150 +- .../trainingjob/definition_v1/__init__.py | 50 +- .../definition_v1/types/__init__.py | 54 +- .../types/automl_image_classification.py | 26 +- .../types/automl_image_object_detection.py | 26 +- .../types/automl_image_segmentation.py | 26 +- .../definition_v1/types/automl_tables.py | 94 +- .../types/automl_text_classification.py | 11 +- .../types/automl_text_extraction.py | 11 +- .../types/automl_text_sentiment.py | 11 +- .../types/automl_video_action_recognition.py | 16 +- .../types/automl_video_classification.py | 16 +- .../types/automl_video_object_tracking.py | 16 +- .../export_evaluated_data_items_config.py | 6 +- .../schema/predict/instance/__init__.py | 54 +- .../predict/instance_v1beta1/__init__.py | 18 +- .../instance_v1beta1/types/__init__.py | 54 +- .../types/image_classification.py | 6 +- .../types/image_object_detection.py | 6 +- .../types/image_segmentation.py | 6 +- .../types/text_classification.py | 6 +- .../instance_v1beta1/types/text_extraction.py | 6 +- .../instance_v1beta1/types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 6 +- .../types/video_classification.py | 6 +- .../types/video_object_tracking.py | 6 +- .../v1beta1/schema/predict/params/__init__.py | 36 +- .../schema/predict/params_v1beta1/__init__.py | 12 +- .../predict/params_v1beta1/types/__init__.py | 36 +- .../types/image_classification.py | 6 +- .../types/image_object_detection.py | 6 +- .../types/image_segmentation.py | 6 +- .../types/video_action_recognition.py | 6 +- .../types/video_classification.py | 6 +- .../types/video_object_tracking.py | 6 +- .../schema/predict/prediction/__init__.py | 60 +- .../predict/prediction_v1beta1/__init__.py | 20 +- .../prediction_v1beta1/types/__init__.py | 60 +- .../types/classification.py | 6 +- .../types/image_object_detection.py | 10 +- .../types/image_segmentation.py | 6 +- .../types/tabular_classification.py | 6 +- .../types/tabular_regression.py | 6 +- .../types/text_extraction.py | 6 +- .../types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 18 +- .../types/video_classification.py | 18 +- .../types/video_object_tracking.py | 43 +- .../schema/trainingjob/definition/__init__.py | 150 +- .../definition_v1beta1/__init__.py | 50 +- .../definition_v1beta1/types/__init__.py | 54 +- .../types/automl_image_classification.py | 26 +- .../types/automl_image_object_detection.py | 26 +- .../types/automl_image_segmentation.py | 26 +- .../definition_v1beta1/types/automl_tables.py | 94 +- .../types/automl_text_classification.py | 11 +- .../types/automl_text_extraction.py | 11 +- .../types/automl_text_sentiment.py | 11 +- .../types/automl_video_action_recognition.py | 16 +- .../types/automl_video_classification.py | 16 +- .../types/automl_video_object_tracking.py | 16 +- .../export_evaluated_data_items_config.py | 6 +- google/cloud/aiplatform_v1/__init__.py | 324 +- .../services/dataset_service/__init__.py | 4 +- .../services/dataset_service/async_client.py | 439 +- .../services/dataset_service/client.py | 543 +- .../services/dataset_service/pagers.py | 113 +- .../dataset_service/transports/__init__.py | 10 +- .../dataset_service/transports/base.py | 237 +- .../dataset_service/transports/grpc.py | 318 +- .../transports/grpc_asyncio.py | 331 +- .../services/endpoint_service/__init__.py | 4 +- .../services/endpoint_service/async_client.py | 331 +- .../services/endpoint_service/client.py | 400 +- .../services/endpoint_service/pagers.py | 45 +- .../endpoint_service/transports/__init__.py | 10 +- .../endpoint_service/transports/base.py | 180 +- .../endpoint_service/transports/grpc.py | 269 +- .../transports/grpc_asyncio.py | 284 +- .../services/job_service/__init__.py | 4 +- .../services/job_service/async_client.py | 794 +- .../services/job_service/client.py | 954 +- .../services/job_service/pagers.py | 157 +- .../job_service/transports/__init__.py | 10 +- .../services/job_service/transports/base.py | 365 +- .../services/job_service/transports/grpc.py | 497 +- .../job_service/transports/grpc_asyncio.py | 516 +- .../services/migration_service/__init__.py | 4 +- .../migration_service/async_client.py | 151 +- .../services/migration_service/client.py | 284 +- .../services/migration_service/pagers.py | 51 +- .../migration_service/transports/__init__.py | 10 +- .../migration_service/transports/base.py | 92 +- .../migration_service/transports/grpc.py | 202 +- .../transports/grpc_asyncio.py | 207 +- .../services/model_service/__init__.py | 4 +- .../services/model_service/async_client.py | 441 +- .../services/model_service/client.py | 555 +- .../services/model_service/pagers.py | 119 +- .../model_service/transports/__init__.py | 10 +- .../services/model_service/transports/base.py | 224 +- .../services/model_service/transports/grpc.py | 318 +- .../model_service/transports/grpc_asyncio.py | 327 +- .../services/pipeline_service/__init__.py | 4 +- .../services/pipeline_service/async_client.py | 249 +- .../services/pipeline_service/client.py | 329 +- .../services/pipeline_service/pagers.py | 51 +- .../pipeline_service/transports/__init__.py | 10 +- .../pipeline_service/transports/base.py | 134 +- .../pipeline_service/transports/grpc.py | 250 +- .../transports/grpc_asyncio.py | 257 +- .../services/prediction_service/__init__.py | 4 +- .../prediction_service/async_client.py | 108 +- .../services/prediction_service/client.py | 166 +- .../prediction_service/transports/__init__.py | 10 +- .../prediction_service/transports/base.py | 84 +- .../prediction_service/transports/grpc.py | 179 +- .../transports/grpc_asyncio.py | 186 +- .../specialist_pool_service/__init__.py | 4 +- .../specialist_pool_service/async_client.py | 264 +- .../specialist_pool_service/client.py | 309 +- .../specialist_pool_service/pagers.py | 51 +- .../transports/__init__.py | 14 +- .../transports/base.py | 135 +- .../transports/grpc.py | 251 +- .../transports/grpc_asyncio.py | 258 +- google/cloud/aiplatform_v1/types/__init__.py | 368 +- .../aiplatform_v1/types/accelerator_type.py | 5 +- .../cloud/aiplatform_v1/types/annotation.py | 21 +- .../aiplatform_v1/types/annotation_spec.py | 13 +- .../types/batch_prediction_job.py | 99 +- .../aiplatform_v1/types/completion_stats.py | 5 +- .../cloud/aiplatform_v1/types/custom_job.py | 86 +- google/cloud/aiplatform_v1/types/data_item.py | 17 +- .../aiplatform_v1/types/data_labeling_job.py | 71 +- google/cloud/aiplatform_v1/types/dataset.py | 32 +- .../aiplatform_v1/types/dataset_service.py | 102 +- .../aiplatform_v1/types/deployed_model_ref.py | 5 +- .../aiplatform_v1/types/encryption_spec.py | 5 +- google/cloud/aiplatform_v1/types/endpoint.py | 36 +- .../aiplatform_v1/types/endpoint_service.py | 68 +- google/cloud/aiplatform_v1/types/env_var.py | 7 +- .../types/hyperparameter_tuning_job.py | 45 +- google/cloud/aiplatform_v1/types/io.py | 12 +- .../cloud/aiplatform_v1/types/job_service.py | 106 +- google/cloud/aiplatform_v1/types/job_state.py | 5 +- .../aiplatform_v1/types/machine_resources.py | 26 +- .../types/manual_batch_tuning_parameters.py | 5 +- .../types/migratable_resource.py | 37 +- .../aiplatform_v1/types/migration_service.py | 87 +- google/cloud/aiplatform_v1/types/model.py | 59 +- .../aiplatform_v1/types/model_evaluation.py | 13 +- .../types/model_evaluation_slice.py | 18 +- .../aiplatform_v1/types/model_service.py | 98 +- google/cloud/aiplatform_v1/types/operation.py | 23 +- .../aiplatform_v1/types/pipeline_service.py | 26 +- .../aiplatform_v1/types/pipeline_state.py | 5 +- .../aiplatform_v1/types/prediction_service.py | 19 +- .../aiplatform_v1/types/specialist_pool.py | 5 +- .../types/specialist_pool_service.py | 46 +- google/cloud/aiplatform_v1/types/study.py | 129 +- .../aiplatform_v1/types/training_pipeline.py | 82 +- .../types/user_action_reference.py | 9 +- google/cloud/aiplatform_v1beta1/__init__.py | 534 +- .../services/dataset_service/__init__.py | 4 +- .../services/dataset_service/async_client.py | 439 +- .../services/dataset_service/client.py | 543 +- .../services/dataset_service/pagers.py | 113 +- .../dataset_service/transports/__init__.py | 10 +- .../dataset_service/transports/base.py | 237 +- .../dataset_service/transports/grpc.py | 318 +- .../transports/grpc_asyncio.py | 331 +- .../services/endpoint_service/__init__.py | 4 +- .../services/endpoint_service/async_client.py | 331 +- .../services/endpoint_service/client.py | 400 +- .../services/endpoint_service/pagers.py | 45 +- .../endpoint_service/transports/__init__.py | 10 +- .../endpoint_service/transports/base.py | 180 +- .../endpoint_service/transports/grpc.py | 269 +- .../transports/grpc_asyncio.py | 284 +- .../services/job_service/__init__.py | 4 +- .../services/job_service/async_client.py | 1513 ++- .../services/job_service/client.py | 1711 +++- .../services/job_service/pagers.py | 403 +- .../job_service/transports/__init__.py | 10 +- .../services/job_service/transports/base.py | 483 +- .../services/job_service/transports/grpc.py | 726 +- .../job_service/transports/grpc_asyncio.py | 745 +- .../services/metadata_service/__init__.py | 24 + .../services/metadata_service/async_client.py | 2487 +++++ .../services/metadata_service/client.py | 2717 ++++++ .../services/metadata_service/pagers.py | 635 ++ .../metadata_service/transports/__init__.py | 35 + .../metadata_service/transports/base.py | 480 + .../metadata_service/transports/grpc.py | 917 ++ .../transports/grpc_asyncio.py | 922 ++ .../services/migration_service/__init__.py | 4 +- .../migration_service/async_client.py | 151 +- .../services/migration_service/client.py | 284 +- .../services/migration_service/pagers.py | 51 +- .../migration_service/transports/__init__.py | 10 +- .../migration_service/transports/base.py | 92 +- .../migration_service/transports/grpc.py | 202 +- .../transports/grpc_asyncio.py | 207 +- .../services/model_service/__init__.py | 4 +- .../services/model_service/async_client.py | 441 +- .../services/model_service/client.py | 555 +- .../services/model_service/pagers.py | 119 +- .../model_service/transports/__init__.py | 10 +- .../services/model_service/transports/base.py | 228 +- .../services/model_service/transports/grpc.py | 318 +- .../model_service/transports/grpc_asyncio.py | 327 +- .../services/pipeline_service/__init__.py | 4 +- .../services/pipeline_service/async_client.py | 253 +- .../services/pipeline_service/client.py | 333 +- .../services/pipeline_service/pagers.py | 51 +- .../pipeline_service/transports/__init__.py | 10 +- .../pipeline_service/transports/base.py | 138 +- .../pipeline_service/transports/grpc.py | 254 +- .../transports/grpc_asyncio.py | 261 +- .../services/prediction_service/__init__.py | 4 +- .../prediction_service/async_client.py | 148 +- .../services/prediction_service/client.py | 206 +- .../prediction_service/transports/__init__.py | 10 +- .../prediction_service/transports/base.py | 103 +- .../prediction_service/transports/grpc.py | 195 +- .../transports/grpc_asyncio.py | 203 +- .../specialist_pool_service/__init__.py | 4 +- .../specialist_pool_service/async_client.py | 264 +- .../specialist_pool_service/client.py | 309 +- .../specialist_pool_service/pagers.py | 51 +- .../transports/__init__.py | 14 +- .../transports/base.py | 135 +- .../transports/grpc.py | 251 +- .../transports/grpc_asyncio.py | 258 +- .../services/vizier_service/__init__.py | 4 +- .../services/vizier_service/async_client.py | 554 +- .../services/vizier_service/client.py | 636 +- .../services/vizier_service/pagers.py | 79 +- .../vizier_service/transports/__init__.py | 10 +- .../vizier_service/transports/base.py | 306 +- .../vizier_service/transports/grpc.py | 384 +- .../vizier_service/transports/grpc_asyncio.py | 398 +- .../aiplatform_v1beta1/types/__init__.py | 600 +- .../types/accelerator_type.py | 5 +- .../aiplatform_v1beta1/types/annotation.py | 21 +- .../types/annotation_spec.py | 13 +- .../aiplatform_v1beta1/types/artifact.py | 149 + .../types/batch_prediction_job.py | 113 +- .../types/completion_stats.py | 5 +- .../cloud/aiplatform_v1beta1/types/context.py | 133 + .../aiplatform_v1beta1/types/custom_job.py | 78 +- .../aiplatform_v1beta1/types/data_item.py | 17 +- .../types/data_labeling_job.py | 71 +- .../cloud/aiplatform_v1beta1/types/dataset.py | 32 +- .../types/dataset_service.py | 102 +- .../types/deployed_model_ref.py | 5 +- .../types/encryption_spec.py | 5 +- .../aiplatform_v1beta1/types/endpoint.py | 40 +- .../types/endpoint_service.py | 68 +- .../cloud/aiplatform_v1beta1/types/env_var.py | 5 +- .../cloud/aiplatform_v1beta1/types/event.py | 86 + .../aiplatform_v1beta1/types/execution.py | 145 + .../aiplatform_v1beta1/types/explanation.py | 104 +- .../types/explanation_metadata.py | 72 +- .../types/feature_monitoring_stats.py | 108 + .../types/hyperparameter_tuning_job.py | 45 +- google/cloud/aiplatform_v1beta1/types/io.py | 12 +- .../aiplatform_v1beta1/types/job_service.py | 398 +- .../aiplatform_v1beta1/types/job_state.py | 5 +- .../types/lineage_subgraph.py | 61 + .../types/machine_resources.py | 36 +- .../types/manual_batch_tuning_parameters.py | 6 +- .../types/metadata_schema.py | 87 + .../types/metadata_service.py | 900 ++ .../types/metadata_store.py | 69 + .../types/migratable_resource.py | 37 +- .../types/migration_service.py | 87 +- .../cloud/aiplatform_v1beta1/types/model.py | 63 +- .../types/model_deployment_monitoring_job.py | 361 + .../types/model_evaluation.py | 26 +- .../types/model_evaluation_slice.py | 18 +- .../types/model_monitoring.py | 219 + .../aiplatform_v1beta1/types/model_service.py | 98 +- .../aiplatform_v1beta1/types/operation.py | 23 +- .../types/pipeline_service.py | 30 +- .../types/pipeline_state.py | 5 +- .../types/prediction_service.py | 42 +- .../types/specialist_pool.py | 5 +- .../types/specialist_pool_service.py | 46 +- .../cloud/aiplatform_v1beta1/types/study.py | 193 +- .../types/training_pipeline.py | 82 +- .../types/user_action_reference.py | 14 +- .../types/vizier_service.py | 98 +- noxfile.py | 72 +- tests/unit/gapic/aiplatform_v1/__init__.py | 1 + .../aiplatform_v1/test_dataset_service.py | 2265 +++-- .../aiplatform_v1/test_endpoint_service.py | 1610 ++-- .../gapic/aiplatform_v1/test_job_service.py | 3685 +++---- .../aiplatform_v1/test_migration_service.py | 952 +- .../gapic/aiplatform_v1/test_model_service.py | 2366 ++--- .../aiplatform_v1/test_pipeline_service.py | 1283 ++- .../test_specialist_pool_service.py | 1156 +-- .../unit/gapic/aiplatform_v1beta1/__init__.py | 1 + .../test_dataset_service.py | 2269 +++-- .../test_endpoint_service.py | 1614 ++-- .../aiplatform_v1beta1/test_job_service.py | 6304 ++++++++---- .../test_metadata_service.py | 8524 +++++++++++++++++ .../test_migration_service.py | 962 +- .../aiplatform_v1beta1/test_model_service.py | 2370 ++--- .../test_pipeline_service.py | 1291 ++- .../test_specialist_pool_service.py | 1156 +-- .../aiplatform_v1beta1/test_vizier_service.py | 2605 +++-- 351 files changed, 57432 insertions(+), 32210 deletions(-) create mode 100644 docs/aiplatform_v1beta1/metadata_service.rst create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/types/artifact.py create mode 100644 google/cloud/aiplatform_v1beta1/types/context.py create mode 100644 google/cloud/aiplatform_v1beta1/types/event.py create mode 100644 google/cloud/aiplatform_v1beta1/types/execution.py create mode 100644 google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py create mode 100644 google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py create mode 100644 google/cloud/aiplatform_v1beta1/types/metadata_schema.py create mode 100644 google/cloud/aiplatform_v1beta1/types/metadata_service.py create mode 100644 google/cloud/aiplatform_v1beta1/types/metadata_store.py create mode 100644 google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py create mode 100644 google/cloud/aiplatform_v1beta1/types/model_monitoring.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9024b15d7..32302e4883 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,6 @@ repos: hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.0 hooks: - id: flake8 diff --git a/docs/aiplatform_v1beta1/metadata_service.rst b/docs/aiplatform_v1beta1/metadata_service.rst new file mode 100644 index 0000000000..c1ebfa9585 --- /dev/null +++ b/docs/aiplatform_v1beta1/metadata_service.rst @@ -0,0 +1,11 @@ +MetadataService +--------------------------------- + +.. automodule:: google.cloud.aiplatform_v1beta1.services.metadata_service + :members: + :inherited-members: + + +.. automodule:: google.cloud.aiplatform_v1beta1.services.metadata_service.pagers + :members: + :inherited-members: diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index 6e4f84c707..95202b1e99 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -6,6 +6,7 @@ Services for Google Cloud Aiplatform v1beta1 API dataset_service endpoint_service job_service + metadata_service migration_service model_service pipeline_service diff --git a/docs/conf.py b/docs/conf.py index 98e68be241..c05116a68c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -347,9 +347,13 @@ intersphinx_mapping = { "python": ("https://python.readthedocs.org/en/latest/", None), "google-auth": ("https://googleapis.dev/python/google-auth/latest/", None), - "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,), + "google.api_core": ( + "https://googleapis.dev/python/google-api-core/latest/", + None, + ), "grpc": ("https://grpc.github.io/grpc/python/", None), "proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None), + } diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py index fb2668afb5..e99be5a9d2 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py @@ -15,42 +15,24 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_classification import ( - ImageClassificationPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_object_detection import ( - ImageObjectDetectionPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_segmentation import ( - ImageSegmentationPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_classification import ( - TextClassificationPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_extraction import ( - TextExtractionPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_sentiment import ( - TextSentimentPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_action_recognition import ( - VideoActionRecognitionPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_classification import ( - VideoClassificationPredictionInstance, -) -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_object_tracking import ( - VideoObjectTrackingPredictionInstance, -) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_classification import ImageClassificationPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_object_detection import ImageObjectDetectionPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_segmentation import ImageSegmentationPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_classification import TextClassificationPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_extraction import TextExtractionPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_sentiment import TextSentimentPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_action_recognition import VideoActionRecognitionPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_classification import VideoClassificationPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_object_tracking import VideoObjectTrackingPredictionInstance __all__ = ( - "ImageClassificationPredictionInstance", - "ImageObjectDetectionPredictionInstance", - "ImageSegmentationPredictionInstance", - "TextClassificationPredictionInstance", - "TextExtractionPredictionInstance", - "TextSentimentPredictionInstance", - "VideoActionRecognitionPredictionInstance", - "VideoClassificationPredictionInstance", - "VideoObjectTrackingPredictionInstance", + 'ImageClassificationPredictionInstance', + 'ImageObjectDetectionPredictionInstance', + 'ImageSegmentationPredictionInstance', + 'TextClassificationPredictionInstance', + 'TextExtractionPredictionInstance', + 'TextSentimentPredictionInstance', + 'VideoActionRecognitionPredictionInstance', + 'VideoClassificationPredictionInstance', + 'VideoObjectTrackingPredictionInstance', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py index f6d9a128ad..c68b05e778 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py @@ -27,13 +27,13 @@ __all__ = ( - "ImageObjectDetectionPredictionInstance", - "ImageSegmentationPredictionInstance", - "TextClassificationPredictionInstance", - "TextExtractionPredictionInstance", - "TextSentimentPredictionInstance", - "VideoActionRecognitionPredictionInstance", - "VideoClassificationPredictionInstance", - "VideoObjectTrackingPredictionInstance", - "ImageClassificationPredictionInstance", + 'ImageObjectDetectionPredictionInstance', + 'ImageSegmentationPredictionInstance', + 'TextClassificationPredictionInstance', + 'TextExtractionPredictionInstance', + 'TextSentimentPredictionInstance', + 'VideoActionRecognitionPredictionInstance', + 'VideoClassificationPredictionInstance', + 'VideoObjectTrackingPredictionInstance', +'ImageClassificationPredictionInstance', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py index 041fe6cdb1..aacf581e2e 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py @@ -15,24 +15,42 @@ # limitations under the License. # -from .image_classification import ImageClassificationPredictionInstance -from .image_object_detection import ImageObjectDetectionPredictionInstance -from .image_segmentation import ImageSegmentationPredictionInstance -from .text_classification import TextClassificationPredictionInstance -from .text_extraction import TextExtractionPredictionInstance -from .text_sentiment import TextSentimentPredictionInstance -from .video_action_recognition import VideoActionRecognitionPredictionInstance -from .video_classification import VideoClassificationPredictionInstance -from .video_object_tracking import VideoObjectTrackingPredictionInstance +from .image_classification import ( + ImageClassificationPredictionInstance, +) +from .image_object_detection import ( + ImageObjectDetectionPredictionInstance, +) +from .image_segmentation import ( + ImageSegmentationPredictionInstance, +) +from .text_classification import ( + TextClassificationPredictionInstance, +) +from .text_extraction import ( + TextExtractionPredictionInstance, +) +from .text_sentiment import ( + TextSentimentPredictionInstance, +) +from .video_action_recognition import ( + VideoActionRecognitionPredictionInstance, +) +from .video_classification import ( + VideoClassificationPredictionInstance, +) +from .video_object_tracking import ( + VideoObjectTrackingPredictionInstance, +) __all__ = ( - "ImageClassificationPredictionInstance", - "ImageObjectDetectionPredictionInstance", - "ImageSegmentationPredictionInstance", - "TextClassificationPredictionInstance", - "TextExtractionPredictionInstance", - "TextSentimentPredictionInstance", - "VideoActionRecognitionPredictionInstance", - "VideoClassificationPredictionInstance", - "VideoObjectTrackingPredictionInstance", + 'ImageClassificationPredictionInstance', + 'ImageObjectDetectionPredictionInstance', + 'ImageSegmentationPredictionInstance', + 'TextClassificationPredictionInstance', + 'TextExtractionPredictionInstance', + 'TextSentimentPredictionInstance', + 'VideoActionRecognitionPredictionInstance', + 'VideoClassificationPredictionInstance', + 'VideoObjectTrackingPredictionInstance', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py index b5fa9b4dbf..2b7e94a11b 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"ImageClassificationPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'ImageClassificationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py index 45752ce7e2..a7ad135173 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"ImageObjectDetectionPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'ImageObjectDetectionPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py index cb436d7029..fb663cb849 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"ImageSegmentationPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'ImageSegmentationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py index ceff5308b7..1d54c594d9 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"TextClassificationPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'TextClassificationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py index 2e96216466..6260e4eca9 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"TextExtractionPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'TextExtractionPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py index 37353ad806..ca47c08fc2 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"TextSentimentPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'TextSentimentPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py index 6de5665312..5e72ebbeae 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"VideoActionRecognitionPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'VideoActionRecognitionPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py index ab7c0edfe1..2a302fc41f 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"VideoClassificationPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'VideoClassificationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py index f797f58f4e..7f1d7b371b 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.instance", - manifest={"VideoObjectTrackingPredictionInstance",}, + package='google.cloud.aiplatform.v1.schema.predict.instance', + manifest={ + 'VideoObjectTrackingPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py index c046f4d7e5..7a3e372796 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py @@ -15,30 +15,18 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_classification import ( - ImageClassificationPredictionParams, -) -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_object_detection import ( - ImageObjectDetectionPredictionParams, -) -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_segmentation import ( - ImageSegmentationPredictionParams, -) -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_action_recognition import ( - VideoActionRecognitionPredictionParams, -) -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_classification import ( - VideoClassificationPredictionParams, -) -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_object_tracking import ( - VideoObjectTrackingPredictionParams, -) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_classification import ImageClassificationPredictionParams +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_object_detection import ImageObjectDetectionPredictionParams +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_segmentation import ImageSegmentationPredictionParams +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_action_recognition import VideoActionRecognitionPredictionParams +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_classification import VideoClassificationPredictionParams +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_object_tracking import VideoObjectTrackingPredictionParams __all__ = ( - "ImageClassificationPredictionParams", - "ImageObjectDetectionPredictionParams", - "ImageSegmentationPredictionParams", - "VideoActionRecognitionPredictionParams", - "VideoClassificationPredictionParams", - "VideoObjectTrackingPredictionParams", + 'ImageClassificationPredictionParams', + 'ImageObjectDetectionPredictionParams', + 'ImageSegmentationPredictionParams', + 'VideoActionRecognitionPredictionParams', + 'VideoClassificationPredictionParams', + 'VideoObjectTrackingPredictionParams', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py index 79fb1c2097..0e358981b3 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py @@ -24,10 +24,10 @@ __all__ = ( - "ImageObjectDetectionPredictionParams", - "ImageSegmentationPredictionParams", - "VideoActionRecognitionPredictionParams", - "VideoClassificationPredictionParams", - "VideoObjectTrackingPredictionParams", - "ImageClassificationPredictionParams", + 'ImageObjectDetectionPredictionParams', + 'ImageSegmentationPredictionParams', + 'VideoActionRecognitionPredictionParams', + 'VideoClassificationPredictionParams', + 'VideoObjectTrackingPredictionParams', +'ImageClassificationPredictionParams', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py index 2f2c29bba5..4f53fda062 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py @@ -15,18 +15,30 @@ # limitations under the License. # -from .image_classification import ImageClassificationPredictionParams -from .image_object_detection import ImageObjectDetectionPredictionParams -from .image_segmentation import ImageSegmentationPredictionParams -from .video_action_recognition import VideoActionRecognitionPredictionParams -from .video_classification import VideoClassificationPredictionParams -from .video_object_tracking import VideoObjectTrackingPredictionParams +from .image_classification import ( + ImageClassificationPredictionParams, +) +from .image_object_detection import ( + ImageObjectDetectionPredictionParams, +) +from .image_segmentation import ( + ImageSegmentationPredictionParams, +) +from .video_action_recognition import ( + VideoActionRecognitionPredictionParams, +) +from .video_classification import ( + VideoClassificationPredictionParams, +) +from .video_object_tracking import ( + VideoObjectTrackingPredictionParams, +) __all__ = ( - "ImageClassificationPredictionParams", - "ImageObjectDetectionPredictionParams", - "ImageSegmentationPredictionParams", - "VideoActionRecognitionPredictionParams", - "VideoClassificationPredictionParams", - "VideoObjectTrackingPredictionParams", + 'ImageClassificationPredictionParams', + 'ImageObjectDetectionPredictionParams', + 'ImageSegmentationPredictionParams', + 'VideoActionRecognitionPredictionParams', + 'VideoClassificationPredictionParams', + 'VideoObjectTrackingPredictionParams', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py index 3a9efd0ea2..b29f91c772 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.params", - manifest={"ImageClassificationPredictionParams",}, + package='google.cloud.aiplatform.v1.schema.predict.params', + manifest={ + 'ImageClassificationPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py index c37507a4e0..7b34fe0395 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.params", - manifest={"ImageObjectDetectionPredictionParams",}, + package='google.cloud.aiplatform.v1.schema.predict.params', + manifest={ + 'ImageObjectDetectionPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py index 108cff107b..3b2f2c3ff2 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.params", - manifest={"ImageSegmentationPredictionParams",}, + package='google.cloud.aiplatform.v1.schema.predict.params', + manifest={ + 'ImageSegmentationPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py index 66f1f19e76..9fbd7a6b6a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.params", - manifest={"VideoActionRecognitionPredictionParams",}, + package='google.cloud.aiplatform.v1.schema.predict.params', + manifest={ + 'VideoActionRecognitionPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py index bfe8df9f5c..cf79e22d5f 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.params", - manifest={"VideoClassificationPredictionParams",}, + package='google.cloud.aiplatform.v1.schema.predict.params', + manifest={ + 'VideoClassificationPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py index 899de1050a..1b1b615d0a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.params", - manifest={"VideoObjectTrackingPredictionParams",}, + package='google.cloud.aiplatform.v1.schema.predict.params', + manifest={ + 'VideoObjectTrackingPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py index d8e2b782c2..01d2f8177a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py @@ -15,46 +15,26 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.classification import ( - ClassificationPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_object_detection import ( - ImageObjectDetectionPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_segmentation import ( - ImageSegmentationPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_classification import ( - TabularClassificationPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_regression import ( - TabularRegressionPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_extraction import ( - TextExtractionPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_sentiment import ( - TextSentimentPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_action_recognition import ( - VideoActionRecognitionPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_classification import ( - VideoClassificationPredictionResult, -) -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_object_tracking import ( - VideoObjectTrackingPredictionResult, -) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.classification import ClassificationPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_object_detection import ImageObjectDetectionPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_segmentation import ImageSegmentationPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_classification import TabularClassificationPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_regression import TabularRegressionPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_extraction import TextExtractionPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_sentiment import TextSentimentPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_action_recognition import VideoActionRecognitionPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_classification import VideoClassificationPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_object_tracking import VideoObjectTrackingPredictionResult __all__ = ( - "ClassificationPredictionResult", - "ImageObjectDetectionPredictionResult", - "ImageSegmentationPredictionResult", - "TabularClassificationPredictionResult", - "TabularRegressionPredictionResult", - "TextExtractionPredictionResult", - "TextSentimentPredictionResult", - "VideoActionRecognitionPredictionResult", - "VideoClassificationPredictionResult", - "VideoObjectTrackingPredictionResult", + 'ClassificationPredictionResult', + 'ImageObjectDetectionPredictionResult', + 'ImageSegmentationPredictionResult', + 'TabularClassificationPredictionResult', + 'TabularRegressionPredictionResult', + 'TextExtractionPredictionResult', + 'TextSentimentPredictionResult', + 'VideoActionRecognitionPredictionResult', + 'VideoClassificationPredictionResult', + 'VideoObjectTrackingPredictionResult', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py index 91fae5a3b1..42f26f575f 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py @@ -28,14 +28,14 @@ __all__ = ( - "ImageObjectDetectionPredictionResult", - "ImageSegmentationPredictionResult", - "TabularClassificationPredictionResult", - "TabularRegressionPredictionResult", - "TextExtractionPredictionResult", - "TextSentimentPredictionResult", - "VideoActionRecognitionPredictionResult", - "VideoClassificationPredictionResult", - "VideoObjectTrackingPredictionResult", - "ClassificationPredictionResult", + 'ImageObjectDetectionPredictionResult', + 'ImageSegmentationPredictionResult', + 'TabularClassificationPredictionResult', + 'TabularRegressionPredictionResult', + 'TextExtractionPredictionResult', + 'TextSentimentPredictionResult', + 'VideoActionRecognitionPredictionResult', + 'VideoClassificationPredictionResult', + 'VideoObjectTrackingPredictionResult', +'ClassificationPredictionResult', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py index a0fd2058e0..019d5ea59c 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py @@ -15,26 +15,46 @@ # limitations under the License. # -from .classification import ClassificationPredictionResult -from .image_object_detection import ImageObjectDetectionPredictionResult -from .image_segmentation import ImageSegmentationPredictionResult -from .tabular_classification import TabularClassificationPredictionResult -from .tabular_regression import TabularRegressionPredictionResult -from .text_extraction import TextExtractionPredictionResult -from .text_sentiment import TextSentimentPredictionResult -from .video_action_recognition import VideoActionRecognitionPredictionResult -from .video_classification import VideoClassificationPredictionResult -from .video_object_tracking import VideoObjectTrackingPredictionResult +from .classification import ( + ClassificationPredictionResult, +) +from .image_object_detection import ( + ImageObjectDetectionPredictionResult, +) +from .image_segmentation import ( + ImageSegmentationPredictionResult, +) +from .tabular_classification import ( + TabularClassificationPredictionResult, +) +from .tabular_regression import ( + TabularRegressionPredictionResult, +) +from .text_extraction import ( + TextExtractionPredictionResult, +) +from .text_sentiment import ( + TextSentimentPredictionResult, +) +from .video_action_recognition import ( + VideoActionRecognitionPredictionResult, +) +from .video_classification import ( + VideoClassificationPredictionResult, +) +from .video_object_tracking import ( + VideoObjectTrackingPredictionResult, +) __all__ = ( - "ClassificationPredictionResult", - "ImageObjectDetectionPredictionResult", - "ImageSegmentationPredictionResult", - "TabularClassificationPredictionResult", - "TabularRegressionPredictionResult", - "TextExtractionPredictionResult", - "TextSentimentPredictionResult", - "VideoActionRecognitionPredictionResult", - "VideoClassificationPredictionResult", - "VideoObjectTrackingPredictionResult", + 'ClassificationPredictionResult', + 'ImageObjectDetectionPredictionResult', + 'ImageSegmentationPredictionResult', + 'TabularClassificationPredictionResult', + 'TabularRegressionPredictionResult', + 'TextExtractionPredictionResult', + 'TextSentimentPredictionResult', + 'VideoActionRecognitionPredictionResult', + 'VideoClassificationPredictionResult', + 'VideoObjectTrackingPredictionResult', ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py index cfc8e2e602..2ae1a3a9cf 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"ClassificationPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'ClassificationPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py index 31d37010db..2987851e58 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py @@ -22,8 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"ImageObjectDetectionPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'ImageObjectDetectionPredictionResult', + }, ) @@ -58,7 +60,9 @@ class ImageObjectDetectionPredictionResult(proto.Message): confidences = proto.RepeatedField(proto.FLOAT, number=3) - bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct.ListValue,) + bboxes = proto.RepeatedField(proto.MESSAGE, number=4, + message=struct.ListValue, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py index 1261f19723..c12b105a2f 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"ImageSegmentationPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'ImageSegmentationPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py index 7e78051467..6ffe672140 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"TabularClassificationPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'TabularClassificationPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py index c813f3e45c..f26cfa1b46 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"TabularRegressionPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'TabularRegressionPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py index 201f10d08a..05234d1324 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"TextExtractionPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'TextExtractionPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py index 73c670f4ec..27501ba0a6 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"TextSentimentPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'TextSentimentPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py index 486853c63d..ad88398dc6 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py @@ -23,8 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"VideoActionRecognitionPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'VideoActionRecognitionPredictionResult', + }, ) @@ -62,13 +64,17 @@ class VideoActionRecognitionPredictionResult(proto.Message): display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field( - proto.MESSAGE, number=4, message=duration.Duration, + time_segment_start = proto.Field(proto.MESSAGE, number=4, + message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) + time_segment_end = proto.Field(proto.MESSAGE, number=5, + message=duration.Duration, + ) - confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) + confidence = proto.Field(proto.MESSAGE, number=6, + message=wrappers.FloatValue, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py index c043547d04..12f042e10e 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py @@ -23,8 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"VideoClassificationPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'VideoClassificationPredictionResult', + }, ) @@ -78,13 +80,17 @@ class VideoClassificationPredictionResult(proto.Message): type_ = proto.Field(proto.STRING, number=3) - time_segment_start = proto.Field( - proto.MESSAGE, number=4, message=duration.Duration, + time_segment_start = proto.Field(proto.MESSAGE, number=4, + message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) + time_segment_end = proto.Field(proto.MESSAGE, number=5, + message=duration.Duration, + ) - confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) + confidence = proto.Field(proto.MESSAGE, number=6, + message=wrappers.FloatValue, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py index d1b515a895..672c039bc6 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py @@ -23,8 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.predict.prediction", - manifest={"VideoObjectTrackingPredictionResult",}, + package='google.cloud.aiplatform.v1.schema.predict.prediction', + manifest={ + 'VideoObjectTrackingPredictionResult', + }, ) @@ -62,7 +64,6 @@ class VideoObjectTrackingPredictionResult(proto.Message): bounding boxes in the frames identify the same object. """ - class Frame(proto.Message): r"""The fields ``xMin``, ``xMax``, ``yMin``, and ``yMax`` refer to a bounding box, i.e. the rectangle over the video frame pinpointing @@ -87,29 +88,45 @@ class Frame(proto.Message): box. """ - time_offset = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + time_offset = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) - x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers.FloatValue,) + x_min = proto.Field(proto.MESSAGE, number=2, + message=wrappers.FloatValue, + ) - x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers.FloatValue,) + x_max = proto.Field(proto.MESSAGE, number=3, + message=wrappers.FloatValue, + ) - y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers.FloatValue,) + y_min = proto.Field(proto.MESSAGE, number=4, + message=wrappers.FloatValue, + ) - y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) + y_max = proto.Field(proto.MESSAGE, number=5, + message=wrappers.FloatValue, + ) id = proto.Field(proto.STRING, number=1) display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field( - proto.MESSAGE, number=3, message=duration.Duration, + time_segment_start = proto.Field(proto.MESSAGE, number=3, + message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=4, message=duration.Duration,) + time_segment_end = proto.Field(proto.MESSAGE, number=4, + message=duration.Duration, + ) - confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) + confidence = proto.Field(proto.MESSAGE, number=5, + message=wrappers.FloatValue, + ) - frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,) + frames = proto.RepeatedField(proto.MESSAGE, number=6, + message=Frame, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py index f8620bb25d..1f57aea67f 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py @@ -15,106 +15,56 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( - AutoMlImageClassification, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( - AutoMlImageClassificationInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( - AutoMlImageClassificationMetadata, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( - AutoMlImageObjectDetection, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( - AutoMlImageObjectDetectionInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( - AutoMlImageObjectDetectionMetadata, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( - AutoMlImageSegmentation, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( - AutoMlImageSegmentationInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( - AutoMlImageSegmentationMetadata, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( - AutoMlTables, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( - AutoMlTablesInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( - AutoMlTablesMetadata, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import ( - AutoMlTextClassification, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import ( - AutoMlTextClassificationInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import ( - AutoMlTextExtraction, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import ( - AutoMlTextExtractionInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import ( - AutoMlTextSentiment, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import ( - AutoMlTextSentimentInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import ( - AutoMlVideoActionRecognition, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import ( - AutoMlVideoActionRecognitionInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import ( - AutoMlVideoClassification, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import ( - AutoMlVideoClassificationInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import ( - AutoMlVideoObjectTracking, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import ( - AutoMlVideoObjectTrackingInputs, -) -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.export_evaluated_data_items_config import ( - ExportEvaluatedDataItemsConfig, -) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import AutoMlImageClassification +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import AutoMlImageClassificationInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import AutoMlImageClassificationMetadata +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import AutoMlImageObjectDetection +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import AutoMlImageObjectDetectionInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import AutoMlImageObjectDetectionMetadata +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import AutoMlImageSegmentation +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import AutoMlImageSegmentationInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import AutoMlImageSegmentationMetadata +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import AutoMlTables +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import AutoMlTablesInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import AutoMlTablesMetadata +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import AutoMlTextClassification +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import AutoMlTextClassificationInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import AutoMlTextExtraction +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import AutoMlTextExtractionInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import AutoMlTextSentiment +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import AutoMlTextSentimentInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import AutoMlVideoActionRecognition +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import AutoMlVideoActionRecognitionInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import AutoMlVideoClassification +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import AutoMlVideoClassificationInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import AutoMlVideoObjectTracking +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import AutoMlVideoObjectTrackingInputs +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig __all__ = ( - "AutoMlImageClassification", - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", - "AutoMlTables", - "AutoMlTablesInputs", - "AutoMlTablesMetadata", - "AutoMlTextClassification", - "AutoMlTextClassificationInputs", - "AutoMlTextExtraction", - "AutoMlTextExtractionInputs", - "AutoMlTextSentiment", - "AutoMlTextSentimentInputs", - "AutoMlVideoActionRecognition", - "AutoMlVideoActionRecognitionInputs", - "AutoMlVideoClassification", - "AutoMlVideoClassificationInputs", - "AutoMlVideoObjectTracking", - "AutoMlVideoObjectTrackingInputs", - "ExportEvaluatedDataItemsConfig", + 'AutoMlImageClassification', + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + 'ExportEvaluatedDataItemsConfig', ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py index 34958e5add..135e04f228 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py @@ -43,29 +43,29 @@ __all__ = ( - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", - "AutoMlTables", - "AutoMlTablesInputs", - "AutoMlTablesMetadata", - "AutoMlTextClassification", - "AutoMlTextClassificationInputs", - "AutoMlTextExtraction", - "AutoMlTextExtractionInputs", - "AutoMlTextSentiment", - "AutoMlTextSentimentInputs", - "AutoMlVideoActionRecognition", - "AutoMlVideoActionRecognitionInputs", - "AutoMlVideoClassification", - "AutoMlVideoClassificationInputs", - "AutoMlVideoObjectTracking", - "AutoMlVideoObjectTrackingInputs", - "ExportEvaluatedDataItemsConfig", - "AutoMlImageClassification", + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + 'ExportEvaluatedDataItemsConfig', +'AutoMlImageClassification', ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py index a15aa2c041..2d7d19c057 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py @@ -59,32 +59,34 @@ AutoMlVideoObjectTracking, AutoMlVideoObjectTrackingInputs, ) -from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig +from .export_evaluated_data_items_config import ( + ExportEvaluatedDataItemsConfig, +) __all__ = ( - "AutoMlImageClassification", - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", - "AutoMlTables", - "AutoMlTablesInputs", - "AutoMlTablesMetadata", - "AutoMlTextClassification", - "AutoMlTextClassificationInputs", - "AutoMlTextExtraction", - "AutoMlTextExtractionInputs", - "AutoMlTextSentiment", - "AutoMlTextSentimentInputs", - "AutoMlVideoActionRecognition", - "AutoMlVideoActionRecognitionInputs", - "AutoMlVideoClassification", - "AutoMlVideoClassificationInputs", - "AutoMlVideoObjectTracking", - "AutoMlVideoObjectTrackingInputs", - "ExportEvaluatedDataItemsConfig", + 'AutoMlImageClassification', + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + 'ExportEvaluatedDataItemsConfig', ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py index f7e13c60b7..530007c977 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', manifest={ - "AutoMlImageClassification", - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", + 'AutoMlImageClassification', + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', }, ) @@ -39,12 +39,12 @@ class AutoMlImageClassification(proto.Message): The metadata information. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlImageClassificationInputs', ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata", + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlImageClassificationMetadata', ) @@ -92,7 +92,6 @@ class AutoMlImageClassificationInputs(proto.Message): be trained (i.e. assuming that for each image multiple annotations may be applicable). """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -101,7 +100,9 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 3 MOBILE_TF_HIGH_ACCURACY_1 = 4 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) base_model_id = proto.Field(proto.STRING, number=2) @@ -126,7 +127,6 @@ class AutoMlImageClassificationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ - class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -135,8 +135,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field( - proto.ENUM, number=2, enum=SuccessfulStopReason, + successful_stop_reason = proto.Field(proto.ENUM, number=2, + enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py index 1c2c9f83b7..9aa8ea5b3d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', manifest={ - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', }, ) @@ -39,12 +39,12 @@ class AutoMlImageObjectDetection(proto.Message): The metadata information """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlImageObjectDetectionInputs', ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata", + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlImageObjectDetectionMetadata', ) @@ -80,7 +80,6 @@ class AutoMlImageObjectDetectionInputs(proto.Message): training before the entire training budget has been used. """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -90,7 +89,9 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 4 MOBILE_TF_HIGH_ACCURACY_1 = 5 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -111,7 +112,6 @@ class AutoMlImageObjectDetectionMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ - class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -120,8 +120,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field( - proto.ENUM, number=2, enum=SuccessfulStopReason, + successful_stop_reason = proto.Field(proto.ENUM, number=2, + enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py index a81103657e..9188939a09 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', manifest={ - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', }, ) @@ -39,12 +39,12 @@ class AutoMlImageSegmentation(proto.Message): The metadata information. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlImageSegmentationInputs', ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata", + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlImageSegmentationMetadata', ) @@ -76,7 +76,6 @@ class AutoMlImageSegmentationInputs(proto.Message): ``base`` model must be in the same Project and Location as the new Model to train, and have the same modelType. """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -84,7 +83,9 @@ class ModelType(proto.Enum): CLOUD_LOW_ACCURACY_1 = 2 MOBILE_TF_LOW_LATENCY_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -105,7 +106,6 @@ class AutoMlImageSegmentationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ - class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -114,8 +114,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field( - proto.ENUM, number=2, enum=SuccessfulStopReason, + successful_stop_reason = proto.Field(proto.ENUM, number=2, + enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py index 1c3d0c8da7..1efe804ca5 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py @@ -18,14 +18,16 @@ import proto # type: ignore -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types import ( - export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config, -) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types import export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + }, ) @@ -39,9 +41,13 @@ class AutoMlTables(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",) + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTablesInputs', + ) - metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",) + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlTablesMetadata', + ) class AutoMlTablesInputs(proto.Message): @@ -146,7 +152,6 @@ class AutoMlTablesInputs(proto.Message): configuration is absent, then the export is not performed. """ - class Transformation(proto.Message): r""" @@ -168,7 +173,6 @@ class Transformation(proto.Message): repeated_text (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation): """ - class AutoTransformation(proto.Message): r"""Training pipeline will infer the proper transformation based on the statistic of dataset. @@ -343,76 +347,48 @@ class TextArrayTransformation(proto.Message): column_name = proto.Field(proto.STRING, number=1) - auto = proto.Field( - proto.MESSAGE, - number=1, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.AutoTransformation", + auto = proto.Field(proto.MESSAGE, number=1, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.AutoTransformation', ) - numeric = proto.Field( - proto.MESSAGE, - number=2, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.NumericTransformation", + numeric = proto.Field(proto.MESSAGE, number=2, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.NumericTransformation', ) - categorical = proto.Field( - proto.MESSAGE, - number=3, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.CategoricalTransformation", + categorical = proto.Field(proto.MESSAGE, number=3, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.CategoricalTransformation', ) - timestamp = proto.Field( - proto.MESSAGE, - number=4, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.TimestampTransformation", + timestamp = proto.Field(proto.MESSAGE, number=4, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.TimestampTransformation', ) - text = proto.Field( - proto.MESSAGE, - number=5, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.TextTransformation", + text = proto.Field(proto.MESSAGE, number=5, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.TextTransformation', ) - repeated_numeric = proto.Field( - proto.MESSAGE, - number=6, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.NumericArrayTransformation", + repeated_numeric = proto.Field(proto.MESSAGE, number=6, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.NumericArrayTransformation', ) - repeated_categorical = proto.Field( - proto.MESSAGE, - number=7, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.CategoricalArrayTransformation", + repeated_categorical = proto.Field(proto.MESSAGE, number=7, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.CategoricalArrayTransformation', ) - repeated_text = proto.Field( - proto.MESSAGE, - number=8, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.TextArrayTransformation", + repeated_text = proto.Field(proto.MESSAGE, number=8, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.TextArrayTransformation', ) - optimization_objective_recall_value = proto.Field( - proto.FLOAT, number=5, oneof="additional_optimization_objective_config" - ) + optimization_objective_recall_value = proto.Field(proto.FLOAT, number=5, oneof='additional_optimization_objective_config') - optimization_objective_precision_value = proto.Field( - proto.FLOAT, number=6, oneof="additional_optimization_objective_config" - ) + optimization_objective_precision_value = proto.Field(proto.FLOAT, number=6, oneof='additional_optimization_objective_config') prediction_type = proto.Field(proto.STRING, number=1) target_column = proto.Field(proto.STRING, number=2) - transformations = proto.RepeatedField( - proto.MESSAGE, number=3, message=Transformation, + transformations = proto.RepeatedField(proto.MESSAGE, number=3, + message=Transformation, ) optimization_objective = proto.Field(proto.STRING, number=4) @@ -423,9 +399,7 @@ class TextArrayTransformation(proto.Message): weight_column_name = proto.Field(proto.STRING, number=9) - export_evaluated_data_items_config = proto.Field( - proto.MESSAGE, - number=10, + export_evaluated_data_items_config = proto.Field(proto.MESSAGE, number=10, message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py index 205deaf375..adcd3a46fb 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlTextClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTextClassificationInputs', ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py index fad28847af..f6d6064504 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + }, ) @@ -33,7 +36,9 @@ class AutoMlTextExtraction(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",) + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTextExtractionInputs', + ) class AutoMlTextExtractionInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py index ca80a44d1d..5d67713e3d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + }, ) @@ -33,7 +36,9 @@ class AutoMlTextSentiment(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",) + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTextSentimentInputs', + ) class AutoMlTextSentimentInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py index 1a20a6d725..06653758a7 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlVideoActionRecognition(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlVideoActionRecognitionInputs', ) @@ -45,14 +48,15 @@ class AutoMlVideoActionRecognitionInputs(proto.Message): model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoActionRecognitionInputs.ModelType): """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 CLOUD = 1 MOBILE_VERSATILE_1 = 2 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py index ba7f2d5b21..486e4d0ecb 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlVideoClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlVideoClassificationInputs', ) @@ -45,7 +48,6 @@ class AutoMlVideoClassificationInputs(proto.Message): model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoClassificationInputs.ModelType): """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -53,7 +55,9 @@ class ModelType(proto.Enum): MOBILE_VERSATILE_1 = 2 MOBILE_JETSON_VERSATILE_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py index 0ecb1113d9..de660f7d1d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlVideoObjectTracking(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlVideoObjectTrackingInputs', ) @@ -45,7 +48,6 @@ class AutoMlVideoObjectTrackingInputs(proto.Message): model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoObjectTrackingInputs.ModelType): """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -56,7 +58,9 @@ class ModelType(proto.Enum): MOBILE_JETSON_VERSATILE_1 = 5 MOBILE_JETSON_LOW_LATENCY_1 = 6 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py index dc8a629412..a5b1fcb542 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1.schema.trainingjob.definition", - manifest={"ExportEvaluatedDataItemsConfig",}, + package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + manifest={ + 'ExportEvaluatedDataItemsConfig', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py index 2f514ac4ed..62c5942a51 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py @@ -15,42 +15,24 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_classification import ( - ImageClassificationPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_object_detection import ( - ImageObjectDetectionPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_segmentation import ( - ImageSegmentationPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_classification import ( - TextClassificationPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_extraction import ( - TextExtractionPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_sentiment import ( - TextSentimentPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_action_recognition import ( - VideoActionRecognitionPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_classification import ( - VideoClassificationPredictionInstance, -) -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_object_tracking import ( - VideoObjectTrackingPredictionInstance, -) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_classification import ImageClassificationPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_object_detection import ImageObjectDetectionPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_segmentation import ImageSegmentationPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_classification import TextClassificationPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_extraction import TextExtractionPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_sentiment import TextSentimentPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_action_recognition import VideoActionRecognitionPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_classification import VideoClassificationPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_object_tracking import VideoObjectTrackingPredictionInstance __all__ = ( - "ImageClassificationPredictionInstance", - "ImageObjectDetectionPredictionInstance", - "ImageSegmentationPredictionInstance", - "TextClassificationPredictionInstance", - "TextExtractionPredictionInstance", - "TextSentimentPredictionInstance", - "VideoActionRecognitionPredictionInstance", - "VideoClassificationPredictionInstance", - "VideoObjectTrackingPredictionInstance", + 'ImageClassificationPredictionInstance', + 'ImageObjectDetectionPredictionInstance', + 'ImageSegmentationPredictionInstance', + 'TextClassificationPredictionInstance', + 'TextExtractionPredictionInstance', + 'TextSentimentPredictionInstance', + 'VideoActionRecognitionPredictionInstance', + 'VideoClassificationPredictionInstance', + 'VideoObjectTrackingPredictionInstance', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py index f6d9a128ad..c68b05e778 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py @@ -27,13 +27,13 @@ __all__ = ( - "ImageObjectDetectionPredictionInstance", - "ImageSegmentationPredictionInstance", - "TextClassificationPredictionInstance", - "TextExtractionPredictionInstance", - "TextSentimentPredictionInstance", - "VideoActionRecognitionPredictionInstance", - "VideoClassificationPredictionInstance", - "VideoObjectTrackingPredictionInstance", - "ImageClassificationPredictionInstance", + 'ImageObjectDetectionPredictionInstance', + 'ImageSegmentationPredictionInstance', + 'TextClassificationPredictionInstance', + 'TextExtractionPredictionInstance', + 'TextSentimentPredictionInstance', + 'VideoActionRecognitionPredictionInstance', + 'VideoClassificationPredictionInstance', + 'VideoObjectTrackingPredictionInstance', +'ImageClassificationPredictionInstance', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py index 041fe6cdb1..aacf581e2e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py @@ -15,24 +15,42 @@ # limitations under the License. # -from .image_classification import ImageClassificationPredictionInstance -from .image_object_detection import ImageObjectDetectionPredictionInstance -from .image_segmentation import ImageSegmentationPredictionInstance -from .text_classification import TextClassificationPredictionInstance -from .text_extraction import TextExtractionPredictionInstance -from .text_sentiment import TextSentimentPredictionInstance -from .video_action_recognition import VideoActionRecognitionPredictionInstance -from .video_classification import VideoClassificationPredictionInstance -from .video_object_tracking import VideoObjectTrackingPredictionInstance +from .image_classification import ( + ImageClassificationPredictionInstance, +) +from .image_object_detection import ( + ImageObjectDetectionPredictionInstance, +) +from .image_segmentation import ( + ImageSegmentationPredictionInstance, +) +from .text_classification import ( + TextClassificationPredictionInstance, +) +from .text_extraction import ( + TextExtractionPredictionInstance, +) +from .text_sentiment import ( + TextSentimentPredictionInstance, +) +from .video_action_recognition import ( + VideoActionRecognitionPredictionInstance, +) +from .video_classification import ( + VideoClassificationPredictionInstance, +) +from .video_object_tracking import ( + VideoObjectTrackingPredictionInstance, +) __all__ = ( - "ImageClassificationPredictionInstance", - "ImageObjectDetectionPredictionInstance", - "ImageSegmentationPredictionInstance", - "TextClassificationPredictionInstance", - "TextExtractionPredictionInstance", - "TextSentimentPredictionInstance", - "VideoActionRecognitionPredictionInstance", - "VideoClassificationPredictionInstance", - "VideoObjectTrackingPredictionInstance", + 'ImageClassificationPredictionInstance', + 'ImageObjectDetectionPredictionInstance', + 'ImageSegmentationPredictionInstance', + 'TextClassificationPredictionInstance', + 'TextExtractionPredictionInstance', + 'TextSentimentPredictionInstance', + 'VideoActionRecognitionPredictionInstance', + 'VideoClassificationPredictionInstance', + 'VideoObjectTrackingPredictionInstance', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py index 84b1ef0bbe..c0a0d477a4 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"ImageClassificationPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'ImageClassificationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py index 79c3efc2c6..32cdc492ad 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"ImageObjectDetectionPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'ImageObjectDetectionPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py index 5a3232c6d2..0e1d5293ea 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"ImageSegmentationPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'ImageSegmentationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py index a615dc7e49..3ea5a96d5d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"TextClassificationPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'TextClassificationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py index c6fecf80b7..d256b7d008 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"TextExtractionPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'TextExtractionPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py index 69836d0e96..0e0a339a1c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"TextSentimentPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'TextSentimentPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py index ae3935d387..14a4e4ffec 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"VideoActionRecognitionPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'VideoActionRecognitionPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py index 2f944bb99e..77e8d9e1c0 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"VideoClassificationPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'VideoClassificationPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py index e635b5174b..ab4b3f282f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.instance", - manifest={"VideoObjectTrackingPredictionInstance",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.instance', + manifest={ + 'VideoObjectTrackingPredictionInstance', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py index dc7cd58e9a..0de177503e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py @@ -15,30 +15,18 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_classification import ( - ImageClassificationPredictionParams, -) -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_object_detection import ( - ImageObjectDetectionPredictionParams, -) -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_segmentation import ( - ImageSegmentationPredictionParams, -) -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_action_recognition import ( - VideoActionRecognitionPredictionParams, -) -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_classification import ( - VideoClassificationPredictionParams, -) -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_object_tracking import ( - VideoObjectTrackingPredictionParams, -) +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_classification import ImageClassificationPredictionParams +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_object_detection import ImageObjectDetectionPredictionParams +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_segmentation import ImageSegmentationPredictionParams +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_action_recognition import VideoActionRecognitionPredictionParams +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_classification import VideoClassificationPredictionParams +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_object_tracking import VideoObjectTrackingPredictionParams __all__ = ( - "ImageClassificationPredictionParams", - "ImageObjectDetectionPredictionParams", - "ImageSegmentationPredictionParams", - "VideoActionRecognitionPredictionParams", - "VideoClassificationPredictionParams", - "VideoObjectTrackingPredictionParams", + 'ImageClassificationPredictionParams', + 'ImageObjectDetectionPredictionParams', + 'ImageSegmentationPredictionParams', + 'VideoActionRecognitionPredictionParams', + 'VideoClassificationPredictionParams', + 'VideoObjectTrackingPredictionParams', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py index 79fb1c2097..0e358981b3 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py @@ -24,10 +24,10 @@ __all__ = ( - "ImageObjectDetectionPredictionParams", - "ImageSegmentationPredictionParams", - "VideoActionRecognitionPredictionParams", - "VideoClassificationPredictionParams", - "VideoObjectTrackingPredictionParams", - "ImageClassificationPredictionParams", + 'ImageObjectDetectionPredictionParams', + 'ImageSegmentationPredictionParams', + 'VideoActionRecognitionPredictionParams', + 'VideoClassificationPredictionParams', + 'VideoObjectTrackingPredictionParams', +'ImageClassificationPredictionParams', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py index 2f2c29bba5..4f53fda062 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py @@ -15,18 +15,30 @@ # limitations under the License. # -from .image_classification import ImageClassificationPredictionParams -from .image_object_detection import ImageObjectDetectionPredictionParams -from .image_segmentation import ImageSegmentationPredictionParams -from .video_action_recognition import VideoActionRecognitionPredictionParams -from .video_classification import VideoClassificationPredictionParams -from .video_object_tracking import VideoObjectTrackingPredictionParams +from .image_classification import ( + ImageClassificationPredictionParams, +) +from .image_object_detection import ( + ImageObjectDetectionPredictionParams, +) +from .image_segmentation import ( + ImageSegmentationPredictionParams, +) +from .video_action_recognition import ( + VideoActionRecognitionPredictionParams, +) +from .video_classification import ( + VideoClassificationPredictionParams, +) +from .video_object_tracking import ( + VideoObjectTrackingPredictionParams, +) __all__ = ( - "ImageClassificationPredictionParams", - "ImageObjectDetectionPredictionParams", - "ImageSegmentationPredictionParams", - "VideoActionRecognitionPredictionParams", - "VideoClassificationPredictionParams", - "VideoObjectTrackingPredictionParams", + 'ImageClassificationPredictionParams', + 'ImageObjectDetectionPredictionParams', + 'ImageSegmentationPredictionParams', + 'VideoActionRecognitionPredictionParams', + 'VideoClassificationPredictionParams', + 'VideoObjectTrackingPredictionParams', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py index 681a8c3d87..1bfe57e1e6 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.params", - manifest={"ImageClassificationPredictionParams",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.params', + manifest={ + 'ImageClassificationPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py index 146dd324b7..ba86d17656 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.params", - manifest={"ImageObjectDetectionPredictionParams",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.params', + manifest={ + 'ImageObjectDetectionPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py index aa11739a61..ab5b028025 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.params", - manifest={"ImageSegmentationPredictionParams",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.params', + manifest={ + 'ImageSegmentationPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py index c1f8f9f3bc..60b9bee8c8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.params", - manifest={"VideoActionRecognitionPredictionParams",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.params', + manifest={ + 'VideoActionRecognitionPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py index 1b8d84a7d1..f90d338919 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.params", - manifest={"VideoClassificationPredictionParams",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.params', + manifest={ + 'VideoClassificationPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py index 4c0b6846bc..7c92def8fc 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.params", - manifest={"VideoObjectTrackingPredictionParams",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.params', + manifest={ + 'VideoObjectTrackingPredictionParams', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py index d5f2762504..5041ec8e6f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py @@ -15,46 +15,26 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.classification import ( - ClassificationPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_object_detection import ( - ImageObjectDetectionPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_segmentation import ( - ImageSegmentationPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_classification import ( - TabularClassificationPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_regression import ( - TabularRegressionPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_extraction import ( - TextExtractionPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_sentiment import ( - TextSentimentPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_action_recognition import ( - VideoActionRecognitionPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_classification import ( - VideoClassificationPredictionResult, -) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_object_tracking import ( - VideoObjectTrackingPredictionResult, -) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.classification import ClassificationPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_object_detection import ImageObjectDetectionPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_segmentation import ImageSegmentationPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_classification import TabularClassificationPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_regression import TabularRegressionPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_extraction import TextExtractionPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_sentiment import TextSentimentPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_action_recognition import VideoActionRecognitionPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_classification import VideoClassificationPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_object_tracking import VideoObjectTrackingPredictionResult __all__ = ( - "ClassificationPredictionResult", - "ImageObjectDetectionPredictionResult", - "ImageSegmentationPredictionResult", - "TabularClassificationPredictionResult", - "TabularRegressionPredictionResult", - "TextExtractionPredictionResult", - "TextSentimentPredictionResult", - "VideoActionRecognitionPredictionResult", - "VideoClassificationPredictionResult", - "VideoObjectTrackingPredictionResult", + 'ClassificationPredictionResult', + 'ImageObjectDetectionPredictionResult', + 'ImageSegmentationPredictionResult', + 'TabularClassificationPredictionResult', + 'TabularRegressionPredictionResult', + 'TextExtractionPredictionResult', + 'TextSentimentPredictionResult', + 'VideoActionRecognitionPredictionResult', + 'VideoClassificationPredictionResult', + 'VideoObjectTrackingPredictionResult', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py index 91fae5a3b1..42f26f575f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py @@ -28,14 +28,14 @@ __all__ = ( - "ImageObjectDetectionPredictionResult", - "ImageSegmentationPredictionResult", - "TabularClassificationPredictionResult", - "TabularRegressionPredictionResult", - "TextExtractionPredictionResult", - "TextSentimentPredictionResult", - "VideoActionRecognitionPredictionResult", - "VideoClassificationPredictionResult", - "VideoObjectTrackingPredictionResult", - "ClassificationPredictionResult", + 'ImageObjectDetectionPredictionResult', + 'ImageSegmentationPredictionResult', + 'TabularClassificationPredictionResult', + 'TabularRegressionPredictionResult', + 'TextExtractionPredictionResult', + 'TextSentimentPredictionResult', + 'VideoActionRecognitionPredictionResult', + 'VideoClassificationPredictionResult', + 'VideoObjectTrackingPredictionResult', +'ClassificationPredictionResult', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py index a0fd2058e0..019d5ea59c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py @@ -15,26 +15,46 @@ # limitations under the License. # -from .classification import ClassificationPredictionResult -from .image_object_detection import ImageObjectDetectionPredictionResult -from .image_segmentation import ImageSegmentationPredictionResult -from .tabular_classification import TabularClassificationPredictionResult -from .tabular_regression import TabularRegressionPredictionResult -from .text_extraction import TextExtractionPredictionResult -from .text_sentiment import TextSentimentPredictionResult -from .video_action_recognition import VideoActionRecognitionPredictionResult -from .video_classification import VideoClassificationPredictionResult -from .video_object_tracking import VideoObjectTrackingPredictionResult +from .classification import ( + ClassificationPredictionResult, +) +from .image_object_detection import ( + ImageObjectDetectionPredictionResult, +) +from .image_segmentation import ( + ImageSegmentationPredictionResult, +) +from .tabular_classification import ( + TabularClassificationPredictionResult, +) +from .tabular_regression import ( + TabularRegressionPredictionResult, +) +from .text_extraction import ( + TextExtractionPredictionResult, +) +from .text_sentiment import ( + TextSentimentPredictionResult, +) +from .video_action_recognition import ( + VideoActionRecognitionPredictionResult, +) +from .video_classification import ( + VideoClassificationPredictionResult, +) +from .video_object_tracking import ( + VideoObjectTrackingPredictionResult, +) __all__ = ( - "ClassificationPredictionResult", - "ImageObjectDetectionPredictionResult", - "ImageSegmentationPredictionResult", - "TabularClassificationPredictionResult", - "TabularRegressionPredictionResult", - "TextExtractionPredictionResult", - "TextSentimentPredictionResult", - "VideoActionRecognitionPredictionResult", - "VideoClassificationPredictionResult", - "VideoObjectTrackingPredictionResult", + 'ClassificationPredictionResult', + 'ImageObjectDetectionPredictionResult', + 'ImageSegmentationPredictionResult', + 'TabularClassificationPredictionResult', + 'TabularRegressionPredictionResult', + 'TextExtractionPredictionResult', + 'TextSentimentPredictionResult', + 'VideoActionRecognitionPredictionResult', + 'VideoClassificationPredictionResult', + 'VideoObjectTrackingPredictionResult', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py index 3bfe82f64e..ed4bcece4f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"ClassificationPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'ClassificationPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py index 3d0f7f1f76..f125a9d4a6 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py @@ -22,8 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"ImageObjectDetectionPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'ImageObjectDetectionPredictionResult', + }, ) @@ -58,7 +60,9 @@ class ImageObjectDetectionPredictionResult(proto.Message): confidences = proto.RepeatedField(proto.FLOAT, number=3) - bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct.ListValue,) + bboxes = proto.RepeatedField(proto.MESSAGE, number=4, + message=struct.ListValue, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py index ffd6fb9380..abc5977b79 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"ImageSegmentationPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'ImageSegmentationPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py index 4906ad59a5..bd373e8e8d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"TabularClassificationPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'TabularClassificationPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py index 71d535c1f0..bc21aaaf8d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"TabularRegressionPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'TabularRegressionPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py index e3c10b5d75..e23faf278f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"TextExtractionPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'TextExtractionPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py index f31b95a18f..9a822e7782 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"TextSentimentPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'TextSentimentPredictionResult', + }, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py index 99fa365b47..6b70a6c36c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py @@ -23,8 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"VideoActionRecognitionPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'VideoActionRecognitionPredictionResult', + }, ) @@ -62,13 +64,17 @@ class VideoActionRecognitionPredictionResult(proto.Message): display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field( - proto.MESSAGE, number=4, message=duration.Duration, + time_segment_start = proto.Field(proto.MESSAGE, number=4, + message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) + time_segment_end = proto.Field(proto.MESSAGE, number=5, + message=duration.Duration, + ) - confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) + confidence = proto.Field(proto.MESSAGE, number=6, + message=wrappers.FloatValue, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py index 3fca68fe64..2b435bbff8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py @@ -23,8 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"VideoClassificationPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'VideoClassificationPredictionResult', + }, ) @@ -78,13 +80,17 @@ class VideoClassificationPredictionResult(proto.Message): type_ = proto.Field(proto.STRING, number=3) - time_segment_start = proto.Field( - proto.MESSAGE, number=4, message=duration.Duration, + time_segment_start = proto.Field(proto.MESSAGE, number=4, + message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) + time_segment_end = proto.Field(proto.MESSAGE, number=5, + message=duration.Duration, + ) - confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) + confidence = proto.Field(proto.MESSAGE, number=6, + message=wrappers.FloatValue, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py index 6fd431c0dd..2bbf98710c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py @@ -23,8 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", - manifest={"VideoObjectTrackingPredictionResult",}, + package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', + manifest={ + 'VideoObjectTrackingPredictionResult', + }, ) @@ -62,7 +64,6 @@ class VideoObjectTrackingPredictionResult(proto.Message): bounding boxes in the frames identify the same object. """ - class Frame(proto.Message): r"""The fields ``xMin``, ``xMax``, ``yMin``, and ``yMax`` refer to a bounding box, i.e. the rectangle over the video frame pinpointing @@ -87,29 +88,45 @@ class Frame(proto.Message): box. """ - time_offset = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + time_offset = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) - x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers.FloatValue,) + x_min = proto.Field(proto.MESSAGE, number=2, + message=wrappers.FloatValue, + ) - x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers.FloatValue,) + x_max = proto.Field(proto.MESSAGE, number=3, + message=wrappers.FloatValue, + ) - y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers.FloatValue,) + y_min = proto.Field(proto.MESSAGE, number=4, + message=wrappers.FloatValue, + ) - y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) + y_max = proto.Field(proto.MESSAGE, number=5, + message=wrappers.FloatValue, + ) id = proto.Field(proto.STRING, number=1) display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field( - proto.MESSAGE, number=3, message=duration.Duration, + time_segment_start = proto.Field(proto.MESSAGE, number=3, + message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=4, message=duration.Duration,) + time_segment_end = proto.Field(proto.MESSAGE, number=4, + message=duration.Duration, + ) - confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) + confidence = proto.Field(proto.MESSAGE, number=5, + message=wrappers.FloatValue, + ) - frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,) + frames = proto.RepeatedField(proto.MESSAGE, number=6, + message=Frame, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py index d632ef9609..9475d2c67c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py @@ -15,106 +15,56 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( - AutoMlImageClassification, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( - AutoMlImageClassificationInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( - AutoMlImageClassificationMetadata, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import ( - AutoMlImageObjectDetection, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import ( - AutoMlImageObjectDetectionInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import ( - AutoMlImageObjectDetectionMetadata, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import ( - AutoMlImageSegmentation, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import ( - AutoMlImageSegmentationInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import ( - AutoMlImageSegmentationMetadata, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import ( - AutoMlTables, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import ( - AutoMlTablesInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import ( - AutoMlTablesMetadata, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import ( - AutoMlTextClassification, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import ( - AutoMlTextClassificationInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import ( - AutoMlTextExtraction, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import ( - AutoMlTextExtractionInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import ( - AutoMlTextSentiment, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import ( - AutoMlTextSentimentInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import ( - AutoMlVideoActionRecognition, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import ( - AutoMlVideoActionRecognitionInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import ( - AutoMlVideoClassification, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import ( - AutoMlVideoClassificationInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import ( - AutoMlVideoObjectTracking, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import ( - AutoMlVideoObjectTrackingInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.export_evaluated_data_items_config import ( - ExportEvaluatedDataItemsConfig, -) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import AutoMlImageClassification +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import AutoMlImageClassificationInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import AutoMlImageClassificationMetadata +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import AutoMlImageObjectDetection +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import AutoMlImageObjectDetectionInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import AutoMlImageObjectDetectionMetadata +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import AutoMlImageSegmentation +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import AutoMlImageSegmentationInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import AutoMlImageSegmentationMetadata +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import AutoMlTables +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import AutoMlTablesInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import AutoMlTablesMetadata +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import AutoMlTextClassification +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import AutoMlTextClassificationInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import AutoMlTextExtraction +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import AutoMlTextExtractionInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import AutoMlTextSentiment +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import AutoMlTextSentimentInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import AutoMlVideoActionRecognition +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import AutoMlVideoActionRecognitionInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import AutoMlVideoClassification +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import AutoMlVideoClassificationInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import AutoMlVideoObjectTracking +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import AutoMlVideoObjectTrackingInputs +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig __all__ = ( - "AutoMlImageClassification", - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", - "AutoMlTables", - "AutoMlTablesInputs", - "AutoMlTablesMetadata", - "AutoMlTextClassification", - "AutoMlTextClassificationInputs", - "AutoMlTextExtraction", - "AutoMlTextExtractionInputs", - "AutoMlTextSentiment", - "AutoMlTextSentimentInputs", - "AutoMlVideoActionRecognition", - "AutoMlVideoActionRecognitionInputs", - "AutoMlVideoClassification", - "AutoMlVideoClassificationInputs", - "AutoMlVideoObjectTracking", - "AutoMlVideoObjectTrackingInputs", - "ExportEvaluatedDataItemsConfig", + 'AutoMlImageClassification', + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + 'ExportEvaluatedDataItemsConfig', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py index 34958e5add..135e04f228 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py @@ -43,29 +43,29 @@ __all__ = ( - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", - "AutoMlTables", - "AutoMlTablesInputs", - "AutoMlTablesMetadata", - "AutoMlTextClassification", - "AutoMlTextClassificationInputs", - "AutoMlTextExtraction", - "AutoMlTextExtractionInputs", - "AutoMlTextSentiment", - "AutoMlTextSentimentInputs", - "AutoMlVideoActionRecognition", - "AutoMlVideoActionRecognitionInputs", - "AutoMlVideoClassification", - "AutoMlVideoClassificationInputs", - "AutoMlVideoObjectTracking", - "AutoMlVideoObjectTrackingInputs", - "ExportEvaluatedDataItemsConfig", - "AutoMlImageClassification", + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + 'ExportEvaluatedDataItemsConfig', +'AutoMlImageClassification', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py index a15aa2c041..2d7d19c057 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py @@ -59,32 +59,34 @@ AutoMlVideoObjectTracking, AutoMlVideoObjectTrackingInputs, ) -from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig +from .export_evaluated_data_items_config import ( + ExportEvaluatedDataItemsConfig, +) __all__ = ( - "AutoMlImageClassification", - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", - "AutoMlTables", - "AutoMlTablesInputs", - "AutoMlTablesMetadata", - "AutoMlTextClassification", - "AutoMlTextClassificationInputs", - "AutoMlTextExtraction", - "AutoMlTextExtractionInputs", - "AutoMlTextSentiment", - "AutoMlTextSentimentInputs", - "AutoMlVideoActionRecognition", - "AutoMlVideoActionRecognitionInputs", - "AutoMlVideoClassification", - "AutoMlVideoClassificationInputs", - "AutoMlVideoObjectTracking", - "AutoMlVideoObjectTrackingInputs", - "ExportEvaluatedDataItemsConfig", + 'AutoMlImageClassification', + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + 'ExportEvaluatedDataItemsConfig', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py index 8ee27076d2..6eb4ada23e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', manifest={ - "AutoMlImageClassification", - "AutoMlImageClassificationInputs", - "AutoMlImageClassificationMetadata", + 'AutoMlImageClassification', + 'AutoMlImageClassificationInputs', + 'AutoMlImageClassificationMetadata', }, ) @@ -39,12 +39,12 @@ class AutoMlImageClassification(proto.Message): The metadata information. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlImageClassificationInputs', ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata", + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlImageClassificationMetadata', ) @@ -92,7 +92,6 @@ class AutoMlImageClassificationInputs(proto.Message): be trained (i.e. assuming that for each image multiple annotations may be applicable). """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -101,7 +100,9 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 3 MOBILE_TF_HIGH_ACCURACY_1 = 4 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) base_model_id = proto.Field(proto.STRING, number=2) @@ -126,7 +127,6 @@ class AutoMlImageClassificationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ - class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -135,8 +135,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field( - proto.ENUM, number=2, enum=SuccessfulStopReason, + successful_stop_reason = proto.Field(proto.ENUM, number=2, + enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py index 512e35ed1d..6cd9a9684d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', manifest={ - "AutoMlImageObjectDetection", - "AutoMlImageObjectDetectionInputs", - "AutoMlImageObjectDetectionMetadata", + 'AutoMlImageObjectDetection', + 'AutoMlImageObjectDetectionInputs', + 'AutoMlImageObjectDetectionMetadata', }, ) @@ -39,12 +39,12 @@ class AutoMlImageObjectDetection(proto.Message): The metadata information """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlImageObjectDetectionInputs', ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata", + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlImageObjectDetectionMetadata', ) @@ -80,7 +80,6 @@ class AutoMlImageObjectDetectionInputs(proto.Message): training before the entire training budget has been used. """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -90,7 +89,9 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 4 MOBILE_TF_HIGH_ACCURACY_1 = 5 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -111,7 +112,6 @@ class AutoMlImageObjectDetectionMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ - class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -120,8 +120,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field( - proto.ENUM, number=2, enum=SuccessfulStopReason, + successful_stop_reason = proto.Field(proto.ENUM, number=2, + enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py index 014df43b2f..28fd9d385d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', manifest={ - "AutoMlImageSegmentation", - "AutoMlImageSegmentationInputs", - "AutoMlImageSegmentationMetadata", + 'AutoMlImageSegmentation', + 'AutoMlImageSegmentationInputs', + 'AutoMlImageSegmentationMetadata', }, ) @@ -39,12 +39,12 @@ class AutoMlImageSegmentation(proto.Message): The metadata information. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlImageSegmentationInputs', ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata", + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlImageSegmentationMetadata', ) @@ -76,7 +76,6 @@ class AutoMlImageSegmentationInputs(proto.Message): ``base`` model must be in the same Project and Location as the new Model to train, and have the same modelType. """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -84,7 +83,9 @@ class ModelType(proto.Enum): CLOUD_LOW_ACCURACY_1 = 2 MOBILE_TF_LOW_LATENCY_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -105,7 +106,6 @@ class AutoMlImageSegmentationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ - class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -114,8 +114,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field( - proto.ENUM, number=2, enum=SuccessfulStopReason, + successful_stop_reason = proto.Field(proto.ENUM, number=2, + enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py index 19c43929e8..a506fe6493 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py @@ -18,14 +18,16 @@ import proto # type: ignore -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import ( - export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config, -) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlTables', + 'AutoMlTablesInputs', + 'AutoMlTablesMetadata', + }, ) @@ -39,9 +41,13 @@ class AutoMlTables(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",) + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTablesInputs', + ) - metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",) + metadata = proto.Field(proto.MESSAGE, number=2, + message='AutoMlTablesMetadata', + ) class AutoMlTablesInputs(proto.Message): @@ -146,7 +152,6 @@ class AutoMlTablesInputs(proto.Message): configuration is absent, then the export is not performed. """ - class Transformation(proto.Message): r""" @@ -168,7 +173,6 @@ class Transformation(proto.Message): repeated_text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation): """ - class AutoTransformation(proto.Message): r"""Training pipeline will infer the proper transformation based on the statistic of dataset. @@ -343,76 +347,48 @@ class TextArrayTransformation(proto.Message): column_name = proto.Field(proto.STRING, number=1) - auto = proto.Field( - proto.MESSAGE, - number=1, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.AutoTransformation", + auto = proto.Field(proto.MESSAGE, number=1, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.AutoTransformation', ) - numeric = proto.Field( - proto.MESSAGE, - number=2, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.NumericTransformation", + numeric = proto.Field(proto.MESSAGE, number=2, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.NumericTransformation', ) - categorical = proto.Field( - proto.MESSAGE, - number=3, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.CategoricalTransformation", + categorical = proto.Field(proto.MESSAGE, number=3, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.CategoricalTransformation', ) - timestamp = proto.Field( - proto.MESSAGE, - number=4, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.TimestampTransformation", + timestamp = proto.Field(proto.MESSAGE, number=4, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.TimestampTransformation', ) - text = proto.Field( - proto.MESSAGE, - number=5, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.TextTransformation", + text = proto.Field(proto.MESSAGE, number=5, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.TextTransformation', ) - repeated_numeric = proto.Field( - proto.MESSAGE, - number=6, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.NumericArrayTransformation", + repeated_numeric = proto.Field(proto.MESSAGE, number=6, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.NumericArrayTransformation', ) - repeated_categorical = proto.Field( - proto.MESSAGE, - number=7, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.CategoricalArrayTransformation", + repeated_categorical = proto.Field(proto.MESSAGE, number=7, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.CategoricalArrayTransformation', ) - repeated_text = proto.Field( - proto.MESSAGE, - number=8, - oneof="transformation_detail", - message="AutoMlTablesInputs.Transformation.TextArrayTransformation", + repeated_text = proto.Field(proto.MESSAGE, number=8, oneof='transformation_detail', + message='AutoMlTablesInputs.Transformation.TextArrayTransformation', ) - optimization_objective_recall_value = proto.Field( - proto.FLOAT, number=5, oneof="additional_optimization_objective_config" - ) + optimization_objective_recall_value = proto.Field(proto.FLOAT, number=5, oneof='additional_optimization_objective_config') - optimization_objective_precision_value = proto.Field( - proto.FLOAT, number=6, oneof="additional_optimization_objective_config" - ) + optimization_objective_precision_value = proto.Field(proto.FLOAT, number=6, oneof='additional_optimization_objective_config') prediction_type = proto.Field(proto.STRING, number=1) target_column = proto.Field(proto.STRING, number=2) - transformations = proto.RepeatedField( - proto.MESSAGE, number=3, message=Transformation, + transformations = proto.RepeatedField(proto.MESSAGE, number=3, + message=Transformation, ) optimization_objective = proto.Field(proto.STRING, number=4) @@ -423,9 +399,7 @@ class TextArrayTransformation(proto.Message): weight_column_name = proto.Field(proto.STRING, number=9) - export_evaluated_data_items_config = proto.Field( - proto.MESSAGE, - number=10, + export_evaluated_data_items_config = proto.Field(proto.MESSAGE, number=10, message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py index 9fe6b865c9..dd9c448258 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlTextClassification', + 'AutoMlTextClassificationInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlTextClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTextClassificationInputs', ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py index c7b1fc6dba..d1111f379f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlTextExtraction', + 'AutoMlTextExtractionInputs', + }, ) @@ -33,7 +36,9 @@ class AutoMlTextExtraction(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",) + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTextExtractionInputs', + ) class AutoMlTextExtractionInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py index 8239b55fdf..06f4fa06f9 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlTextSentiment', + 'AutoMlTextSentimentInputs', + }, ) @@ -33,7 +36,9 @@ class AutoMlTextSentiment(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",) + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlTextSentimentInputs', + ) class AutoMlTextSentimentInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py index 66448faf01..e795fa10c5 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlVideoActionRecognition', + 'AutoMlVideoActionRecognitionInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlVideoActionRecognition(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlVideoActionRecognitionInputs', ) @@ -45,14 +48,15 @@ class AutoMlVideoActionRecognitionInputs(proto.Message): model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoActionRecognitionInputs.ModelType): """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 CLOUD = 1 MOBILE_VERSATILE_1 = 2 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py index e1c12eb46c..2d3ffbf007 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlVideoClassification', + 'AutoMlVideoClassificationInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlVideoClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlVideoClassificationInputs', ) @@ -45,7 +48,6 @@ class AutoMlVideoClassificationInputs(proto.Message): model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoClassificationInputs.ModelType): """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -53,7 +55,9 @@ class ModelType(proto.Enum): MOBILE_VERSATILE_1 = 2 MOBILE_JETSON_VERSATILE_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py index 328e266a3b..adf69eee56 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py @@ -19,8 +19,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'AutoMlVideoObjectTracking', + 'AutoMlVideoObjectTrackingInputs', + }, ) @@ -33,8 +36,8 @@ class AutoMlVideoObjectTracking(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field( - proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs", + inputs = proto.Field(proto.MESSAGE, number=1, + message='AutoMlVideoObjectTrackingInputs', ) @@ -45,7 +48,6 @@ class AutoMlVideoObjectTrackingInputs(proto.Message): model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoObjectTrackingInputs.ModelType): """ - class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -56,7 +58,9 @@ class ModelType(proto.Enum): MOBILE_JETSON_VERSATILE_1 = 5 MOBILE_JETSON_LOW_LATENCY_1 = 6 - model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + model_type = proto.Field(proto.ENUM, number=1, + enum=ModelType, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py index 9a6195fec2..2770d78441 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", - manifest={"ExportEvaluatedDataItemsConfig",}, + package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + manifest={ + 'ExportEvaluatedDataItemsConfig', + }, ) diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index 1b0c76e834..24c5acb6bb 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -180,166 +180,166 @@ __all__ = ( - "AcceleratorType", - "ActiveLearningConfig", - "Annotation", - "AnnotationSpec", - "AutomaticResources", - "BatchDedicatedResources", - "BatchMigrateResourcesOperationMetadata", - "BatchMigrateResourcesRequest", - "BatchMigrateResourcesResponse", - "BatchPredictionJob", - "BigQueryDestination", - "BigQuerySource", - "CancelBatchPredictionJobRequest", - "CancelCustomJobRequest", - "CancelDataLabelingJobRequest", - "CancelHyperparameterTuningJobRequest", - "CancelTrainingPipelineRequest", - "CompletionStats", - "ContainerRegistryDestination", - "ContainerSpec", - "CreateBatchPredictionJobRequest", - "CreateCustomJobRequest", - "CreateDataLabelingJobRequest", - "CreateDatasetOperationMetadata", - "CreateDatasetRequest", - "CreateEndpointOperationMetadata", - "CreateEndpointRequest", - "CreateHyperparameterTuningJobRequest", - "CreateSpecialistPoolOperationMetadata", - "CreateSpecialistPoolRequest", - "CreateTrainingPipelineRequest", - "CustomJob", - "CustomJobSpec", - "DataItem", - "DataLabelingJob", - "Dataset", - "DatasetServiceClient", - "DedicatedResources", - "DeleteBatchPredictionJobRequest", - "DeleteCustomJobRequest", - "DeleteDataLabelingJobRequest", - "DeleteDatasetRequest", - "DeleteEndpointRequest", - "DeleteHyperparameterTuningJobRequest", - "DeleteModelRequest", - "DeleteOperationMetadata", - "DeleteSpecialistPoolRequest", - "DeleteTrainingPipelineRequest", - "DeployModelOperationMetadata", - "DeployModelRequest", - "DeployModelResponse", - "DeployedModel", - "DeployedModelRef", - "DiskSpec", - "EncryptionSpec", - "Endpoint", - "EndpointServiceClient", - "EnvVar", - "ExportDataConfig", - "ExportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportModelOperationMetadata", - "ExportModelRequest", - "ExportModelResponse", - "FilterSplit", - "FractionSplit", - "GcsDestination", - "GcsSource", - "GenericOperationMetadata", - "GetAnnotationSpecRequest", - "GetBatchPredictionJobRequest", - "GetCustomJobRequest", - "GetDataLabelingJobRequest", - "GetDatasetRequest", - "GetEndpointRequest", - "GetHyperparameterTuningJobRequest", - "GetModelEvaluationRequest", - "GetModelEvaluationSliceRequest", - "GetModelRequest", - "GetSpecialistPoolRequest", - "GetTrainingPipelineRequest", - "HyperparameterTuningJob", - "ImportDataConfig", - "ImportDataOperationMetadata", - "ImportDataRequest", - "ImportDataResponse", - "InputDataConfig", - "JobServiceClient", - "JobState", - "ListAnnotationsRequest", - "ListAnnotationsResponse", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "ListDataItemsRequest", - "ListDataItemsResponse", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "ListDatasetsRequest", - "ListDatasetsResponse", - "ListEndpointsRequest", - "ListEndpointsResponse", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "ListModelsRequest", - "ListModelsResponse", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "MachineSpec", - "ManualBatchTuningParameters", - "Measurement", - "MigratableResource", - "MigrateResourceRequest", - "MigrateResourceResponse", - "MigrationServiceClient", - "Model", - "ModelContainerSpec", - "ModelEvaluation", - "ModelEvaluationSlice", - "ModelServiceClient", - "PipelineServiceClient", - "PipelineState", - "Port", - "PredefinedSplit", - "PredictRequest", - "PredictResponse", - "PredictSchemata", - "PredictionServiceClient", - "PythonPackageSpec", - "ResourcesConsumed", - "SampleConfig", - "Scheduling", - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", - "SpecialistPool", - "StudySpec", - "TimestampSplit", - "TrainingConfig", - "TrainingPipeline", - "Trial", - "UndeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UpdateDatasetRequest", - "UpdateEndpointRequest", - "UpdateModelRequest", - "UpdateSpecialistPoolOperationMetadata", - "UpdateSpecialistPoolRequest", - "UploadModelOperationMetadata", - "UploadModelRequest", - "UploadModelResponse", - "UserActionReference", - "WorkerPoolSpec", - "SpecialistPoolServiceClient", + 'AcceleratorType', + 'ActiveLearningConfig', + 'Annotation', + 'AnnotationSpec', + 'AutomaticResources', + 'BatchDedicatedResources', + 'BatchMigrateResourcesOperationMetadata', + 'BatchMigrateResourcesRequest', + 'BatchMigrateResourcesResponse', + 'BatchPredictionJob', + 'BigQueryDestination', + 'BigQuerySource', + 'CancelBatchPredictionJobRequest', + 'CancelCustomJobRequest', + 'CancelDataLabelingJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CancelTrainingPipelineRequest', + 'CompletionStats', + 'ContainerRegistryDestination', + 'ContainerSpec', + 'CreateBatchPredictionJobRequest', + 'CreateCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'CreateDatasetOperationMetadata', + 'CreateDatasetRequest', + 'CreateEndpointOperationMetadata', + 'CreateEndpointRequest', + 'CreateHyperparameterTuningJobRequest', + 'CreateSpecialistPoolOperationMetadata', + 'CreateSpecialistPoolRequest', + 'CreateTrainingPipelineRequest', + 'CustomJob', + 'CustomJobSpec', + 'DataItem', + 'DataLabelingJob', + 'Dataset', + 'DatasetServiceClient', + 'DedicatedResources', + 'DeleteBatchPredictionJobRequest', + 'DeleteCustomJobRequest', + 'DeleteDataLabelingJobRequest', + 'DeleteDatasetRequest', + 'DeleteEndpointRequest', + 'DeleteHyperparameterTuningJobRequest', + 'DeleteModelRequest', + 'DeleteOperationMetadata', + 'DeleteSpecialistPoolRequest', + 'DeleteTrainingPipelineRequest', + 'DeployModelOperationMetadata', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployedModel', + 'DeployedModelRef', + 'DiskSpec', + 'EncryptionSpec', + 'Endpoint', + 'EndpointServiceClient', + 'EnvVar', + 'ExportDataConfig', + 'ExportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportModelOperationMetadata', + 'ExportModelRequest', + 'ExportModelResponse', + 'FilterSplit', + 'FractionSplit', + 'GcsDestination', + 'GcsSource', + 'GenericOperationMetadata', + 'GetAnnotationSpecRequest', + 'GetBatchPredictionJobRequest', + 'GetCustomJobRequest', + 'GetDataLabelingJobRequest', + 'GetDatasetRequest', + 'GetEndpointRequest', + 'GetHyperparameterTuningJobRequest', + 'GetModelEvaluationRequest', + 'GetModelEvaluationSliceRequest', + 'GetModelRequest', + 'GetSpecialistPoolRequest', + 'GetTrainingPipelineRequest', + 'HyperparameterTuningJob', + 'ImportDataConfig', + 'ImportDataOperationMetadata', + 'ImportDataRequest', + 'ImportDataResponse', + 'InputDataConfig', + 'JobServiceClient', + 'JobState', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'ListModelsRequest', + 'ListModelsResponse', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'MachineSpec', + 'ManualBatchTuningParameters', + 'Measurement', + 'MigratableResource', + 'MigrateResourceRequest', + 'MigrateResourceResponse', + 'MigrationServiceClient', + 'Model', + 'ModelContainerSpec', + 'ModelEvaluation', + 'ModelEvaluationSlice', + 'ModelServiceClient', + 'PipelineServiceClient', + 'PipelineState', + 'Port', + 'PredefinedSplit', + 'PredictRequest', + 'PredictResponse', + 'PredictSchemata', + 'PredictionServiceClient', + 'PythonPackageSpec', + 'ResourcesConsumed', + 'SampleConfig', + 'Scheduling', + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'SpecialistPool', + 'StudySpec', + 'TimestampSplit', + 'TrainingConfig', + 'TrainingPipeline', + 'Trial', + 'UndeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UpdateDatasetRequest', + 'UpdateEndpointRequest', + 'UpdateModelRequest', + 'UpdateSpecialistPoolOperationMetadata', + 'UpdateSpecialistPoolRequest', + 'UploadModelOperationMetadata', + 'UploadModelRequest', + 'UploadModelResponse', + 'UserActionReference', + 'WorkerPoolSpec', +'SpecialistPoolServiceClient', ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py index 597f654cb9..9d1f004f6a 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import DatasetServiceAsyncClient __all__ = ( - "DatasetServiceClient", - "DatasetServiceAsyncClient", + 'DatasetServiceClient', + 'DatasetServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index a07ee32dfd..950d920c5a 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,42 +60,26 @@ class DatasetServiceAsyncClient: annotation_path = staticmethod(DatasetServiceClient.annotation_path) parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) - parse_annotation_spec_path = staticmethod( - DatasetServiceClient.parse_annotation_spec_path - ) + parse_annotation_spec_path = staticmethod(DatasetServiceClient.parse_annotation_spec_path) data_item_path = staticmethod(DatasetServiceClient.data_item_path) parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) dataset_path = staticmethod(DatasetServiceClient.dataset_path) parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) - common_billing_account_path = staticmethod( - DatasetServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - DatasetServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(DatasetServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(DatasetServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - DatasetServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(DatasetServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - DatasetServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - DatasetServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(DatasetServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(DatasetServiceClient.parse_common_organization_path) common_project_path = staticmethod(DatasetServiceClient.common_project_path) - parse_common_project_path = staticmethod( - DatasetServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(DatasetServiceClient.parse_common_project_path) common_location_path = staticmethod(DatasetServiceClient.common_location_path) - parse_common_location_path = staticmethod( - DatasetServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(DatasetServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -138,18 +122,14 @@ def transport(self) -> DatasetServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) - ) + get_transport_class = functools.partial(type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -188,18 +168,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_dataset( - self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_dataset(self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a Dataset. Args: @@ -240,10 +220,8 @@ async def create_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.CreateDatasetRequest(request) @@ -266,11 +244,18 @@ async def create_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -283,15 +268,14 @@ async def create_dataset( # Done; return the response. return response - async def get_dataset( - self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + async def get_dataset(self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -323,10 +307,8 @@ async def get_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.GetDatasetRequest(request) @@ -347,25 +329,31 @@ async def get_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def update_dataset( - self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + async def update_dataset(self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -410,10 +398,8 @@ async def update_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.UpdateDatasetRequest(request) @@ -436,26 +422,30 @@ async def update_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("dataset.name", request.dataset.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('dataset.name', request.dataset.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_datasets( - self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsAsyncPager: + async def list_datasets(self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: r"""Lists Datasets in a Location. Args: @@ -490,10 +480,8 @@ async def list_datasets( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ListDatasetsRequest(request) @@ -514,30 +502,39 @@ async def list_datasets( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDatasetsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_dataset( - self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_dataset(self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Dataset. Args: @@ -583,10 +580,8 @@ async def delete_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.DeleteDatasetRequest(request) @@ -607,11 +602,18 @@ async def delete_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -624,16 +626,15 @@ async def delete_dataset( # Done; return the response. return response - async def import_data( - self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def import_data(self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Imports data into a Dataset. Args: @@ -677,10 +678,8 @@ async def import_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ImportDataRequest(request) @@ -704,11 +703,18 @@ async def import_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -721,16 +727,15 @@ async def import_data( # Done; return the response. return response - async def export_data( - self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_data(self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports data from a Dataset. Args: @@ -773,10 +778,8 @@ async def export_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ExportDataRequest(request) @@ -799,11 +802,18 @@ async def export_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -816,15 +826,14 @@ async def export_data( # Done; return the response. return response - async def list_data_items( - self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsAsyncPager: + async def list_data_items(self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: r"""Lists DataItems in a Dataset. Args: @@ -860,10 +869,8 @@ async def list_data_items( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ListDataItemsRequest(request) @@ -884,30 +891,39 @@ async def list_data_items( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataItemsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def get_annotation_spec( - self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + async def get_annotation_spec(self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -941,10 +957,8 @@ async def get_annotation_spec( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.GetAnnotationSpecRequest(request) @@ -965,24 +979,30 @@ async def get_annotation_spec( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_annotations( - self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsAsyncPager: + async def list_annotations(self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1019,10 +1039,8 @@ async def list_annotations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ListAnnotationsRequest(request) @@ -1043,30 +1061,47 @@ async def list_annotations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListAnnotationsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("DatasetServiceAsyncClient",) +__all__ = ( + 'DatasetServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 160a2049b8..52109ac90b 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,14 +60,13 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry['grpc'] = DatasetServiceGrpcTransport + _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry["grpc"] = DatasetServiceGrpcTransport - _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -118,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -153,8 +152,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DatasetServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,149 +169,110 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path( - project: str, location: str, dataset: str, data_item: str, annotation: str, - ) -> str: + def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( - project=project, - location=location, - dataset=dataset, - data_item=data_item, - annotation=annotation, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str, str]: + def parse_annotation_path(path: str) -> Dict[str,str]: """Parse a annotation path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path( - project: str, location: str, dataset: str, annotation_spec: str, - ) -> str: + def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( - project=project, - location=location, - dataset=dataset, - annotation_spec=annotation_spec, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str, str]: + def parse_annotation_spec_path(path: str) -> Dict[str,str]: """Parse a annotation_spec path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def data_item_path( - project: str, location: str, dataset: str, data_item: str, - ) -> str: + def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( - project=project, location=location, dataset=dataset, data_item=data_item, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str, str]: + def parse_data_item_path(path: str) -> Dict[str,str]: """Parse a data_item path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -355,9 +316,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -367,9 +326,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -381,9 +338,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -395,10 +350,8 @@ def __init__( if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -417,16 +370,15 @@ def __init__( client_info=client_info, ) - def create_dataset( - self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_dataset(self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a Dataset. Args: @@ -467,10 +419,8 @@ def create_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -494,11 +444,18 @@ def create_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -511,15 +468,14 @@ def create_dataset( # Done; return the response. return response - def get_dataset( - self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset(self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -551,10 +507,8 @@ def get_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -576,25 +530,31 @@ def get_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def update_dataset( - self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset(self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -639,10 +599,8 @@ def update_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -666,26 +624,30 @@ def update_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("dataset.name", request.dataset.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('dataset.name', request.dataset.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_datasets( - self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets(self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -720,10 +682,8 @@ def list_datasets( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -745,30 +705,39 @@ def list_datasets( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_dataset( - self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_dataset(self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Dataset. Args: @@ -814,10 +783,8 @@ def delete_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -839,11 +806,18 @@ def delete_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -856,16 +830,15 @@ def delete_dataset( # Done; return the response. return response - def import_data( - self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def import_data(self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Imports data into a Dataset. Args: @@ -909,10 +882,8 @@ def import_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -936,11 +907,18 @@ def import_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -953,16 +931,15 @@ def import_data( # Done; return the response. return response - def export_data( - self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_data(self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports data from a Dataset. Args: @@ -1005,10 +982,8 @@ def export_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -1032,11 +1007,18 @@ def export_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1049,15 +1031,14 @@ def export_data( # Done; return the response. return response - def list_data_items( - self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items(self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -1093,10 +1074,8 @@ def list_data_items( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1118,30 +1097,39 @@ def list_data_items( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec( - self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec(self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1175,10 +1163,8 @@ def get_annotation_spec( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1200,24 +1186,30 @@ def get_annotation_spec( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_annotations( - self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations(self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1254,10 +1246,8 @@ def list_annotations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1279,30 +1269,47 @@ def list_annotations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("DatasetServiceClient",) +__all__ = ( + 'DatasetServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py index c3f8265b6e..3439dc331c 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import annotation from google.cloud.aiplatform_v1.types import data_item @@ -49,15 +40,12 @@ class ListDatasetsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., dataset_service.ListDatasetsResponse], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -91,7 +79,7 @@ def __iter__(self) -> Iterable[dataset.Dataset]: yield from page.datasets def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDatasetsAsyncPager: @@ -111,15 +99,12 @@ class ListDatasetsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -157,7 +142,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataItemsPager: @@ -177,15 +162,12 @@ class ListDataItemsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., dataset_service.ListDataItemsResponse], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -219,7 +201,7 @@ def __iter__(self) -> Iterable[data_item.DataItem]: yield from page.data_items def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataItemsAsyncPager: @@ -239,15 +221,12 @@ class ListDataItemsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -285,7 +264,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListAnnotationsPager: @@ -305,15 +284,12 @@ class ListAnnotationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., dataset_service.ListAnnotationsResponse], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -347,7 +323,7 @@ def __iter__(self) -> Iterable[annotation.Annotation]: yield from page.annotations def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListAnnotationsAsyncPager: @@ -367,15 +343,12 @@ class ListAnnotationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -413,4 +386,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py index a4461d2ced..5f02a0f0d9 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] -_transport_registry["grpc"] = DatasetServiceGrpcTransport -_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = DatasetServiceGrpcTransport +_transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport __all__ = ( - "DatasetServiceTransport", - "DatasetServiceGrpcTransport", - "DatasetServiceGrpcAsyncIOTransport", + 'DatasetServiceTransport', + 'DatasetServiceGrpcTransport', + 'DatasetServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py index 2ab4419d03..15daeb6369 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -74,73 +74,92 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, default_timeout=None, client_info=client_info, + self.create_dataset, + default_timeout=None, + client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, default_timeout=None, client_info=client_info, + self.get_dataset, + default_timeout=None, + client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, default_timeout=None, client_info=client_info, + self.update_dataset, + default_timeout=None, + client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, default_timeout=None, client_info=client_info, + self.list_datasets, + default_timeout=None, + client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, default_timeout=None, client_info=client_info, + self.delete_dataset, + default_timeout=None, + client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( - self.import_data, default_timeout=None, client_info=client_info, + self.import_data, + default_timeout=None, + client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( - self.export_data, default_timeout=None, client_info=client_info, + self.export_data, + default_timeout=None, + client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, default_timeout=None, client_info=client_info, + self.list_data_items, + default_timeout=None, + client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, default_timeout=None, client_info=client_info, + self.get_annotation_spec, + default_timeout=None, + client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, default_timeout=None, client_info=client_info, + self.list_annotations, + default_timeout=None, + client_info=client_info, ), + } @property @@ -149,106 +168,96 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_dataset( - self, - ) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def create_dataset(self) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_dataset( - self, - ) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], - ]: + def get_dataset(self) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[ + dataset.Dataset, + typing.Awaitable[dataset.Dataset] + ]]: raise NotImplementedError() @property - def update_dataset( - self, - ) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], - ]: + def update_dataset(self) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[ + gca_dataset.Dataset, + typing.Awaitable[gca_dataset.Dataset] + ]]: raise NotImplementedError() @property - def list_datasets( - self, - ) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse], - ], - ]: + def list_datasets(self) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse] + ]]: raise NotImplementedError() @property - def delete_dataset( - self, - ) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_dataset(self) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def import_data( - self, - ) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def import_data(self) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def export_data( - self, - ) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def export_data(self) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def list_data_items( - self, - ) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse], - ], - ]: + def list_data_items(self) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse] + ]]: raise NotImplementedError() @property - def get_annotation_spec( - self, - ) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec], - ], - ]: + def get_annotation_spec(self) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec] + ]]: raise NotImplementedError() @property - def list_annotations( - self, - ) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse], - ], - ]: + def list_annotations(self) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse] + ]]: raise NotImplementedError() -__all__ = ("DatasetServiceTransport",) +__all__ = ( + 'DatasetServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py index 20a01deb79..96efd8e427 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -46,24 +46,21 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -109,7 +106,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -117,70 +117,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -188,32 +168,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -243,12 +211,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -260,15 +229,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_dataset( - self, - ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: + def create_dataset(self) -> Callable[ + [dataset_service.CreateDatasetRequest], + operations.Operation]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -283,18 +254,18 @@ def create_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_dataset" not in self._stubs: - self._stubs["create_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/CreateDataset", + if 'create_dataset' not in self._stubs: + self._stubs['create_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/CreateDataset', request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_dataset"] + return self._stubs['create_dataset'] @property - def get_dataset( - self, - ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: + def get_dataset(self) -> Callable[ + [dataset_service.GetDatasetRequest], + dataset.Dataset]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -309,18 +280,18 @@ def get_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_dataset" not in self._stubs: - self._stubs["get_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/GetDataset", + if 'get_dataset' not in self._stubs: + self._stubs['get_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/GetDataset', request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs["get_dataset"] + return self._stubs['get_dataset'] @property - def update_dataset( - self, - ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: + def update_dataset(self) -> Callable[ + [dataset_service.UpdateDatasetRequest], + gca_dataset.Dataset]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -335,20 +306,18 @@ def update_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_dataset" not in self._stubs: - self._stubs["update_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/UpdateDataset", + if 'update_dataset' not in self._stubs: + self._stubs['update_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/UpdateDataset', request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs["update_dataset"] + return self._stubs['update_dataset'] @property - def list_datasets( - self, - ) -> Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse - ]: + def list_datasets(self) -> Callable[ + [dataset_service.ListDatasetsRequest], + dataset_service.ListDatasetsResponse]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -363,18 +332,18 @@ def list_datasets( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_datasets" not in self._stubs: - self._stubs["list_datasets"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ListDatasets", + if 'list_datasets' not in self._stubs: + self._stubs['list_datasets'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ListDatasets', request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs["list_datasets"] + return self._stubs['list_datasets'] @property - def delete_dataset( - self, - ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: + def delete_dataset(self) -> Callable[ + [dataset_service.DeleteDatasetRequest], + operations.Operation]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -389,18 +358,18 @@ def delete_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_dataset" not in self._stubs: - self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/DeleteDataset", + if 'delete_dataset' not in self._stubs: + self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/DeleteDataset', request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_dataset"] + return self._stubs['delete_dataset'] @property - def import_data( - self, - ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: + def import_data(self) -> Callable[ + [dataset_service.ImportDataRequest], + operations.Operation]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -415,18 +384,18 @@ def import_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "import_data" not in self._stubs: - self._stubs["import_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ImportData", + if 'import_data' not in self._stubs: + self._stubs['import_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ImportData', request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["import_data"] + return self._stubs['import_data'] @property - def export_data( - self, - ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: + def export_data(self) -> Callable[ + [dataset_service.ExportDataRequest], + operations.Operation]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -441,20 +410,18 @@ def export_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_data" not in self._stubs: - self._stubs["export_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ExportData", + if 'export_data' not in self._stubs: + self._stubs['export_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ExportData', request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_data"] + return self._stubs['export_data'] @property - def list_data_items( - self, - ) -> Callable[ - [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse - ]: + def list_data_items(self) -> Callable[ + [dataset_service.ListDataItemsRequest], + dataset_service.ListDataItemsResponse]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -469,20 +436,18 @@ def list_data_items( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_items" not in self._stubs: - self._stubs["list_data_items"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ListDataItems", + if 'list_data_items' not in self._stubs: + self._stubs['list_data_items'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ListDataItems', request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs["list_data_items"] + return self._stubs['list_data_items'] @property - def get_annotation_spec( - self, - ) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec - ]: + def get_annotation_spec(self) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + annotation_spec.AnnotationSpec]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -497,21 +462,18 @@ def get_annotation_spec( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_annotation_spec" not in self._stubs: - self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec", + if 'get_annotation_spec' not in self._stubs: + self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec', request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs["get_annotation_spec"] + return self._stubs['get_annotation_spec'] @property - def list_annotations( - self, - ) -> Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, - ]: + def list_annotations(self) -> Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -526,13 +488,15 @@ def list_annotations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_annotations" not in self._stubs: - self._stubs["list_annotations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ListAnnotations", + if 'list_annotations' not in self._stubs: + self._stubs['list_annotations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ListAnnotations', request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs["list_annotations"] + return self._stubs['list_annotations'] -__all__ = ("DatasetServiceGrpcTransport",) +__all__ = ( + 'DatasetServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py index bcf3331d6b..924299a2f7 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import annotation_spec @@ -53,18 +53,16 @@ class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -90,24 +88,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -142,10 +138,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -154,7 +150,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -162,70 +161,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -233,18 +212,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -273,11 +242,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_dataset( - self, - ) -> Callable[ - [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] - ]: + def create_dataset(self) -> Callable[ + [dataset_service.CreateDatasetRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -292,18 +259,18 @@ def create_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_dataset" not in self._stubs: - self._stubs["create_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/CreateDataset", + if 'create_dataset' not in self._stubs: + self._stubs['create_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/CreateDataset', request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_dataset"] + return self._stubs['create_dataset'] @property - def get_dataset( - self, - ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: + def get_dataset(self) -> Callable[ + [dataset_service.GetDatasetRequest], + Awaitable[dataset.Dataset]]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -318,20 +285,18 @@ def get_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_dataset" not in self._stubs: - self._stubs["get_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/GetDataset", + if 'get_dataset' not in self._stubs: + self._stubs['get_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/GetDataset', request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs["get_dataset"] + return self._stubs['get_dataset'] @property - def update_dataset( - self, - ) -> Callable[ - [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] - ]: + def update_dataset(self) -> Callable[ + [dataset_service.UpdateDatasetRequest], + Awaitable[gca_dataset.Dataset]]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -346,21 +311,18 @@ def update_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_dataset" not in self._stubs: - self._stubs["update_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/UpdateDataset", + if 'update_dataset' not in self._stubs: + self._stubs['update_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/UpdateDataset', request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs["update_dataset"] + return self._stubs['update_dataset'] @property - def list_datasets( - self, - ) -> Callable[ - [dataset_service.ListDatasetsRequest], - Awaitable[dataset_service.ListDatasetsResponse], - ]: + def list_datasets(self) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse]]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -375,20 +337,18 @@ def list_datasets( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_datasets" not in self._stubs: - self._stubs["list_datasets"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ListDatasets", + if 'list_datasets' not in self._stubs: + self._stubs['list_datasets'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ListDatasets', request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs["list_datasets"] + return self._stubs['list_datasets'] @property - def delete_dataset( - self, - ) -> Callable[ - [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] - ]: + def delete_dataset(self) -> Callable[ + [dataset_service.DeleteDatasetRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -403,18 +363,18 @@ def delete_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_dataset" not in self._stubs: - self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/DeleteDataset", + if 'delete_dataset' not in self._stubs: + self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/DeleteDataset', request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_dataset"] + return self._stubs['delete_dataset'] @property - def import_data( - self, - ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: + def import_data(self) -> Callable[ + [dataset_service.ImportDataRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -429,18 +389,18 @@ def import_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "import_data" not in self._stubs: - self._stubs["import_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ImportData", + if 'import_data' not in self._stubs: + self._stubs['import_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ImportData', request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["import_data"] + return self._stubs['import_data'] @property - def export_data( - self, - ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: + def export_data(self) -> Callable[ + [dataset_service.ExportDataRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -455,21 +415,18 @@ def export_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_data" not in self._stubs: - self._stubs["export_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ExportData", + if 'export_data' not in self._stubs: + self._stubs['export_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ExportData', request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_data"] + return self._stubs['export_data'] @property - def list_data_items( - self, - ) -> Callable[ - [dataset_service.ListDataItemsRequest], - Awaitable[dataset_service.ListDataItemsResponse], - ]: + def list_data_items(self) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse]]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -484,21 +441,18 @@ def list_data_items( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_items" not in self._stubs: - self._stubs["list_data_items"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ListDataItems", + if 'list_data_items' not in self._stubs: + self._stubs['list_data_items'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ListDataItems', request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs["list_data_items"] + return self._stubs['list_data_items'] @property - def get_annotation_spec( - self, - ) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - Awaitable[annotation_spec.AnnotationSpec], - ]: + def get_annotation_spec(self) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + Awaitable[annotation_spec.AnnotationSpec]]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -513,21 +467,18 @@ def get_annotation_spec( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_annotation_spec" not in self._stubs: - self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec", + if 'get_annotation_spec' not in self._stubs: + self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec', request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs["get_annotation_spec"] + return self._stubs['get_annotation_spec'] @property - def list_annotations( - self, - ) -> Callable[ - [dataset_service.ListAnnotationsRequest], - Awaitable[dataset_service.ListAnnotationsResponse], - ]: + def list_annotations(self) -> Callable[ + [dataset_service.ListAnnotationsRequest], + Awaitable[dataset_service.ListAnnotationsResponse]]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -542,13 +493,15 @@ def list_annotations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_annotations" not in self._stubs: - self._stubs["list_annotations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.DatasetService/ListAnnotations", + if 'list_annotations' not in self._stubs: + self._stubs['list_annotations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.DatasetService/ListAnnotations', request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs["list_annotations"] + return self._stubs['list_annotations'] -__all__ = ("DatasetServiceGrpcAsyncIOTransport",) +__all__ = ( + 'DatasetServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py index 035a5b2388..e4f3dcfbcf 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import EndpointServiceAsyncClient __all__ = ( - "EndpointServiceClient", - "EndpointServiceAsyncClient", + 'EndpointServiceClient', + 'EndpointServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index 13f099328b..244c35bcba 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -58,34 +58,20 @@ class EndpointServiceAsyncClient: model_path = staticmethod(EndpointServiceClient.model_path) parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) - common_billing_account_path = staticmethod( - EndpointServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - EndpointServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(EndpointServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(EndpointServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - EndpointServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(EndpointServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - EndpointServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - EndpointServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(EndpointServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(EndpointServiceClient.parse_common_organization_path) common_project_path = staticmethod(EndpointServiceClient.common_project_path) - parse_common_project_path = staticmethod( - EndpointServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(EndpointServiceClient.parse_common_project_path) common_location_path = staticmethod(EndpointServiceClient.common_location_path) - parse_common_location_path = staticmethod( - EndpointServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(EndpointServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -128,18 +114,14 @@ def transport(self) -> EndpointServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) - ) + get_transport_class = functools.partial(type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -178,18 +160,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_endpoint( - self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_endpoint(self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an Endpoint. Args: @@ -229,10 +211,8 @@ async def create_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.CreateEndpointRequest(request) @@ -255,11 +235,18 @@ async def create_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -272,15 +259,14 @@ async def create_endpoint( # Done; return the response. return response - async def get_endpoint( - self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + async def get_endpoint(self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -313,10 +299,8 @@ async def get_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.GetEndpointRequest(request) @@ -337,24 +321,30 @@ async def get_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_endpoints( - self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsAsyncPager: + async def list_endpoints(self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: r"""Lists Endpoints in a Location. Args: @@ -390,10 +380,8 @@ async def list_endpoints( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.ListEndpointsRequest(request) @@ -414,31 +402,40 @@ async def list_endpoints( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListEndpointsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def update_endpoint( - self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + async def update_endpoint(self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -478,10 +475,8 @@ async def update_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.UpdateEndpointRequest(request) @@ -504,26 +499,30 @@ async def update_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("endpoint.name", request.endpoint.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint.name', request.endpoint.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def delete_endpoint( - self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_endpoint(self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an Endpoint. Args: @@ -569,10 +568,8 @@ async def delete_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.DeleteEndpointRequest(request) @@ -593,11 +590,18 @@ async def delete_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -610,19 +614,16 @@ async def delete_endpoint( # Done; return the response. return response - async def deploy_model( - self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[ - endpoint_service.DeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def deploy_model(self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -691,10 +692,8 @@ async def deploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.DeployModelRequest(request) @@ -720,11 +719,18 @@ async def deploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -737,19 +743,16 @@ async def deploy_model( # Done; return the response. return response - async def undeploy_model( - self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[ - endpoint_service.UndeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def undeploy_model(self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -809,10 +812,8 @@ async def undeploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.UndeployModelRequest(request) @@ -838,11 +839,18 @@ async def undeploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -856,14 +864,21 @@ async def undeploy_model( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("EndpointServiceAsyncClient",) +__all__ = ( + 'EndpointServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index de54b0b9b5..3b78f5902e 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -56,14 +56,13 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry['grpc'] = EndpointServiceGrpcTransport + _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry["grpc"] = EndpointServiceGrpcTransport - _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -114,7 +113,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -149,8 +148,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: EndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,104 +165,88 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -306,9 +290,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -318,9 +300,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -332,9 +312,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -346,10 +324,8 @@ def __init__( if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -368,16 +344,15 @@ def __init__( client_info=client_info, ) - def create_endpoint( - self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_endpoint(self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates an Endpoint. Args: @@ -417,10 +392,8 @@ def create_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -444,11 +417,18 @@ def create_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -461,15 +441,14 @@ def create_endpoint( # Done; return the response. return response - def get_endpoint( - self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint(self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -502,10 +481,8 @@ def get_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -527,24 +504,30 @@ def get_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_endpoints( - self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints(self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -580,10 +563,8 @@ def list_endpoints( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -605,31 +586,40 @@ def list_endpoints( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def update_endpoint( - self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint(self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -669,10 +659,8 @@ def update_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -696,26 +684,30 @@ def update_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("endpoint.name", request.endpoint.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint.name', request.endpoint.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_endpoint( - self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_endpoint(self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes an Endpoint. Args: @@ -761,10 +753,8 @@ def delete_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -786,11 +776,18 @@ def delete_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -803,19 +800,16 @@ def delete_endpoint( # Done; return the response. return response - def deploy_model( - self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[ - endpoint_service.DeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def deploy_model(self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -884,10 +878,8 @@ def deploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -913,11 +905,18 @@ def deploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -930,19 +929,16 @@ def deploy_model( # Done; return the response. return response - def undeploy_model( - self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[ - endpoint_service.UndeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def undeploy_model(self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -1002,10 +998,8 @@ def undeploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -1031,11 +1025,18 @@ def undeploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1049,14 +1050,21 @@ def undeploy_model( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("EndpointServiceClient",) +__all__ = ( + 'EndpointServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py index c22df91c8c..154c455826 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import endpoint from google.cloud.aiplatform_v1.types import endpoint_service @@ -47,15 +38,12 @@ class ListEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., endpoint_service.ListEndpointsResponse], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: yield from page.endpoints def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListEndpointsAsyncPager: @@ -109,15 +97,12 @@ class ListEndpointsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -155,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py index 3d0695461d..eb2ef767fe 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] -_transport_registry["grpc"] = EndpointServiceGrpcTransport -_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = EndpointServiceGrpcTransport +_transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport __all__ = ( - "EndpointServiceTransport", - "EndpointServiceGrpcTransport", - "EndpointServiceGrpcAsyncIOTransport", + 'EndpointServiceTransport', + 'EndpointServiceGrpcTransport', + 'EndpointServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index 728c38fec3..43520356ad 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -73,64 +73,77 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, default_timeout=None, client_info=client_info, + self.create_endpoint, + default_timeout=None, + client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, default_timeout=None, client_info=client_info, + self.get_endpoint, + default_timeout=None, + client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, default_timeout=None, client_info=client_info, + self.list_endpoints, + default_timeout=None, + client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, default_timeout=None, client_info=client_info, + self.update_endpoint, + default_timeout=None, + client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, default_timeout=None, client_info=client_info, + self.delete_endpoint, + default_timeout=None, + client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, default_timeout=None, client_info=client_info, + self.deploy_model, + default_timeout=None, + client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, default_timeout=None, client_info=client_info, + self.undeploy_model, + default_timeout=None, + client_info=client_info, ), + } @property @@ -139,70 +152,69 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def create_endpoint(self) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], - ]: + def get_endpoint(self) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[ + endpoint.Endpoint, + typing.Awaitable[endpoint.Endpoint] + ]]: raise NotImplementedError() @property - def list_endpoints( - self, - ) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse], - ], - ]: + def list_endpoints(self) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse] + ]]: raise NotImplementedError() @property - def update_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], - ]: + def update_endpoint(self) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[ + gca_endpoint.Endpoint, + typing.Awaitable[gca_endpoint.Endpoint] + ]]: raise NotImplementedError() @property - def delete_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_endpoint(self) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def deploy_model( - self, - ) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def deploy_model(self) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def undeploy_model( - self, - ) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def undeploy_model(self) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() -__all__ = ("EndpointServiceTransport",) +__all__ = ( + 'EndpointServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index d2c13c3fe7..448aa173b9 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -45,24 +45,21 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -108,7 +105,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -116,70 +116,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -187,32 +167,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -242,12 +210,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -259,15 +228,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_endpoint( - self, - ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]: + def create_endpoint(self) -> Callable[ + [endpoint_service.CreateEndpointRequest], + operations.Operation]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -282,18 +253,18 @@ def create_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_endpoint" not in self._stubs: - self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint", + if 'create_endpoint' not in self._stubs: + self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint', request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_endpoint"] + return self._stubs['create_endpoint'] @property - def get_endpoint( - self, - ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: + def get_endpoint(self) -> Callable[ + [endpoint_service.GetEndpointRequest], + endpoint.Endpoint]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -308,20 +279,18 @@ def get_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_endpoint" not in self._stubs: - self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/GetEndpoint", + if 'get_endpoint' not in self._stubs: + self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/GetEndpoint', request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs["get_endpoint"] + return self._stubs['get_endpoint'] @property - def list_endpoints( - self, - ) -> Callable[ - [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse - ]: + def list_endpoints(self) -> Callable[ + [endpoint_service.ListEndpointsRequest], + endpoint_service.ListEndpointsResponse]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -336,18 +305,18 @@ def list_endpoints( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_endpoints" not in self._stubs: - self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/ListEndpoints", + if 'list_endpoints' not in self._stubs: + self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/ListEndpoints', request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs["list_endpoints"] + return self._stubs['list_endpoints'] @property - def update_endpoint( - self, - ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: + def update_endpoint(self) -> Callable[ + [endpoint_service.UpdateEndpointRequest], + gca_endpoint.Endpoint]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -362,18 +331,18 @@ def update_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_endpoint" not in self._stubs: - self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint", + if 'update_endpoint' not in self._stubs: + self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint', request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs["update_endpoint"] + return self._stubs['update_endpoint'] @property - def delete_endpoint( - self, - ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: + def delete_endpoint(self) -> Callable[ + [endpoint_service.DeleteEndpointRequest], + operations.Operation]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -388,18 +357,18 @@ def delete_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_endpoint" not in self._stubs: - self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint", + if 'delete_endpoint' not in self._stubs: + self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint', request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_endpoint"] + return self._stubs['delete_endpoint'] @property - def deploy_model( - self, - ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: + def deploy_model(self) -> Callable[ + [endpoint_service.DeployModelRequest], + operations.Operation]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -415,18 +384,18 @@ def deploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "deploy_model" not in self._stubs: - self._stubs["deploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/DeployModel", + if 'deploy_model' not in self._stubs: + self._stubs['deploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/DeployModel', request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["deploy_model"] + return self._stubs['deploy_model'] @property - def undeploy_model( - self, - ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: + def undeploy_model(self) -> Callable[ + [endpoint_service.UndeployModelRequest], + operations.Operation]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -443,13 +412,15 @@ def undeploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "undeploy_model" not in self._stubs: - self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/UndeployModel", + if 'undeploy_model' not in self._stubs: + self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/UndeployModel', request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["undeploy_model"] + return self._stubs['undeploy_model'] -__all__ = ("EndpointServiceGrpcTransport",) +__all__ = ( + 'EndpointServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py index ef97ba490f..14e2735edd 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import endpoint @@ -52,18 +52,16 @@ class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -89,24 +87,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -141,10 +137,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -153,7 +149,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -161,70 +160,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -232,18 +211,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -272,11 +241,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_endpoint( - self, - ) -> Callable[ - [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] - ]: + def create_endpoint(self) -> Callable[ + [endpoint_service.CreateEndpointRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -291,18 +258,18 @@ def create_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_endpoint" not in self._stubs: - self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint", + if 'create_endpoint' not in self._stubs: + self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint', request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_endpoint"] + return self._stubs['create_endpoint'] @property - def get_endpoint( - self, - ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: + def get_endpoint(self) -> Callable[ + [endpoint_service.GetEndpointRequest], + Awaitable[endpoint.Endpoint]]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -317,21 +284,18 @@ def get_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_endpoint" not in self._stubs: - self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/GetEndpoint", + if 'get_endpoint' not in self._stubs: + self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/GetEndpoint', request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs["get_endpoint"] + return self._stubs['get_endpoint'] @property - def list_endpoints( - self, - ) -> Callable[ - [endpoint_service.ListEndpointsRequest], - Awaitable[endpoint_service.ListEndpointsResponse], - ]: + def list_endpoints(self) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse]]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -346,20 +310,18 @@ def list_endpoints( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_endpoints" not in self._stubs: - self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/ListEndpoints", + if 'list_endpoints' not in self._stubs: + self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/ListEndpoints', request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs["list_endpoints"] + return self._stubs['list_endpoints'] @property - def update_endpoint( - self, - ) -> Callable[ - [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] - ]: + def update_endpoint(self) -> Callable[ + [endpoint_service.UpdateEndpointRequest], + Awaitable[gca_endpoint.Endpoint]]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -374,20 +336,18 @@ def update_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_endpoint" not in self._stubs: - self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint", + if 'update_endpoint' not in self._stubs: + self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint', request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs["update_endpoint"] + return self._stubs['update_endpoint'] @property - def delete_endpoint( - self, - ) -> Callable[ - [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] - ]: + def delete_endpoint(self) -> Callable[ + [endpoint_service.DeleteEndpointRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -402,20 +362,18 @@ def delete_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_endpoint" not in self._stubs: - self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint", + if 'delete_endpoint' not in self._stubs: + self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint', request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_endpoint"] + return self._stubs['delete_endpoint'] @property - def deploy_model( - self, - ) -> Callable[ - [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] - ]: + def deploy_model(self) -> Callable[ + [endpoint_service.DeployModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -431,20 +389,18 @@ def deploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "deploy_model" not in self._stubs: - self._stubs["deploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/DeployModel", + if 'deploy_model' not in self._stubs: + self._stubs['deploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/DeployModel', request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["deploy_model"] + return self._stubs['deploy_model'] @property - def undeploy_model( - self, - ) -> Callable[ - [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] - ]: + def undeploy_model(self) -> Callable[ + [endpoint_service.UndeployModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -461,13 +417,15 @@ def undeploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "undeploy_model" not in self._stubs: - self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.EndpointService/UndeployModel", + if 'undeploy_model' not in self._stubs: + self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.EndpointService/UndeployModel', request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["undeploy_model"] + return self._stubs['undeploy_model'] -__all__ = ("EndpointServiceGrpcAsyncIOTransport",) +__all__ = ( + 'EndpointServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/job_service/__init__.py b/google/cloud/aiplatform_v1/services/job_service/__init__.py index 5f157047f5..037407b714 100644 --- a/google/cloud/aiplatform_v1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/job_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import JobServiceAsyncClient __all__ = ( - "JobServiceClient", - "JobServiceAsyncClient", + 'JobServiceClient', + 'JobServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index e253bcc5d6..e76498a85d 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -21,20 +21,18 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -42,9 +40,7 @@ from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources @@ -71,50 +67,34 @@ class JobServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) - parse_batch_prediction_job_path = staticmethod( - JobServiceClient.parse_batch_prediction_job_path - ) + parse_batch_prediction_job_path = staticmethod(JobServiceClient.parse_batch_prediction_job_path) custom_job_path = staticmethod(JobServiceClient.custom_job_path) parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) - parse_data_labeling_job_path = staticmethod( - JobServiceClient.parse_data_labeling_job_path - ) + parse_data_labeling_job_path = staticmethod(JobServiceClient.parse_data_labeling_job_path) dataset_path = staticmethod(JobServiceClient.dataset_path) parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) - hyperparameter_tuning_job_path = staticmethod( - JobServiceClient.hyperparameter_tuning_job_path - ) - parse_hyperparameter_tuning_job_path = staticmethod( - JobServiceClient.parse_hyperparameter_tuning_job_path - ) + hyperparameter_tuning_job_path = staticmethod(JobServiceClient.hyperparameter_tuning_job_path) + parse_hyperparameter_tuning_job_path = staticmethod(JobServiceClient.parse_hyperparameter_tuning_job_path) model_path = staticmethod(JobServiceClient.model_path) parse_model_path = staticmethod(JobServiceClient.parse_model_path) trial_path = staticmethod(JobServiceClient.trial_path) parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) - common_billing_account_path = staticmethod( - JobServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - JobServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(JobServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(JobServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(JobServiceClient.common_folder_path) parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) common_organization_path = staticmethod(JobServiceClient.common_organization_path) - parse_common_organization_path = staticmethod( - JobServiceClient.parse_common_organization_path - ) + parse_common_organization_path = staticmethod(JobServiceClient.parse_common_organization_path) common_project_path = staticmethod(JobServiceClient.common_project_path) parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) common_location_path = staticmethod(JobServiceClient.common_location_path) - parse_common_location_path = staticmethod( - JobServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(JobServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -157,18 +137,14 @@ def transport(self) -> JobServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(JobServiceClient).get_transport_class, type(JobServiceClient) - ) + get_transport_class = functools.partial(type(JobServiceClient).get_transport_class, type(JobServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, JobServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -207,18 +183,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_custom_job( - self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + async def create_custom_job(self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -263,10 +239,8 @@ async def create_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateCustomJobRequest(request) @@ -289,24 +263,30 @@ async def create_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_custom_job( - self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + async def get_custom_job(self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -344,10 +324,8 @@ async def get_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetCustomJobRequest(request) @@ -368,24 +346,30 @@ async def get_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_custom_jobs( - self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsAsyncPager: + async def list_custom_jobs(self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: r"""Lists CustomJobs in a Location. Args: @@ -421,10 +405,8 @@ async def list_custom_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListCustomJobsRequest(request) @@ -445,30 +427,39 @@ async def list_custom_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListCustomJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_custom_job( - self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_custom_job(self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a CustomJob. Args: @@ -514,10 +505,8 @@ async def delete_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteCustomJobRequest(request) @@ -538,11 +527,18 @@ async def delete_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -555,15 +551,14 @@ async def delete_custom_job( # Done; return the response. return response - async def cancel_custom_job( - self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_custom_job(self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -601,10 +596,8 @@ async def cancel_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelCustomJobRequest(request) @@ -625,24 +618,28 @@ async def cancel_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_data_labeling_job( - self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_data_labeling_job(self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -682,10 +679,8 @@ async def create_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateDataLabelingJobRequest(request) @@ -708,24 +703,30 @@ async def create_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_data_labeling_job( - self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + async def get_data_labeling_job(self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -759,10 +760,8 @@ async def get_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetDataLabelingJobRequest(request) @@ -783,24 +782,30 @@ async def get_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_data_labeling_jobs( - self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsAsyncPager: + async def list_data_labeling_jobs(self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -835,10 +840,8 @@ async def list_data_labeling_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListDataLabelingJobsRequest(request) @@ -859,30 +862,39 @@ async def list_data_labeling_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataLabelingJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_data_labeling_job( - self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_data_labeling_job(self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a DataLabelingJob. Args: @@ -929,10 +941,8 @@ async def delete_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteDataLabelingJobRequest(request) @@ -953,11 +963,18 @@ async def delete_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -970,15 +987,14 @@ async def delete_data_labeling_job( # Done; return the response. return response - async def cancel_data_labeling_job( - self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_data_labeling_job(self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1006,10 +1022,8 @@ async def cancel_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelDataLabelingJobRequest(request) @@ -1030,24 +1044,28 @@ async def cancel_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_hyperparameter_tuning_job( - self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_hyperparameter_tuning_job(self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1089,10 +1107,8 @@ async def create_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateHyperparameterTuningJobRequest(request) @@ -1115,24 +1131,30 @@ async def create_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_hyperparameter_tuning_job( - self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + async def get_hyperparameter_tuning_job(self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1168,10 +1190,8 @@ async def get_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetHyperparameterTuningJobRequest(request) @@ -1192,24 +1212,30 @@ async def get_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_hyperparameter_tuning_jobs( - self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + async def list_hyperparameter_tuning_jobs(self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1245,10 +1271,8 @@ async def list_hyperparameter_tuning_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListHyperparameterTuningJobsRequest(request) @@ -1269,30 +1293,39 @@ async def list_hyperparameter_tuning_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListHyperparameterTuningJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_hyperparameter_tuning_job( - self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_hyperparameter_tuning_job(self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1339,10 +1372,8 @@ async def delete_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteHyperparameterTuningJobRequest(request) @@ -1363,11 +1394,18 @@ async def delete_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1380,15 +1418,14 @@ async def delete_hyperparameter_tuning_job( # Done; return the response. return response - async def cancel_hyperparameter_tuning_job( - self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_hyperparameter_tuning_job(self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1429,10 +1466,8 @@ async def cancel_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelHyperparameterTuningJobRequest(request) @@ -1453,24 +1488,28 @@ async def cancel_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_batch_prediction_job( - self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_batch_prediction_job(self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1515,10 +1554,8 @@ async def create_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateBatchPredictionJobRequest(request) @@ -1541,24 +1578,30 @@ async def create_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_batch_prediction_job( - self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + async def get_batch_prediction_job(self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1596,10 +1639,8 @@ async def get_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetBatchPredictionJobRequest(request) @@ -1620,24 +1661,30 @@ async def get_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_batch_prediction_jobs( - self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsAsyncPager: + async def list_batch_prediction_jobs(self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsAsyncPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1673,10 +1720,8 @@ async def list_batch_prediction_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListBatchPredictionJobsRequest(request) @@ -1697,30 +1742,39 @@ async def list_batch_prediction_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListBatchPredictionJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_batch_prediction_job( - self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_batch_prediction_job(self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -1768,10 +1822,8 @@ async def delete_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteBatchPredictionJobRequest(request) @@ -1792,11 +1844,18 @@ async def delete_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1809,15 +1868,14 @@ async def delete_batch_prediction_job( # Done; return the response. return response - async def cancel_batch_prediction_job( - self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_batch_prediction_job(self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -1856,10 +1914,8 @@ async def cancel_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelBatchPredictionJobRequest(request) @@ -1880,23 +1936,35 @@ async def cancel_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("JobServiceAsyncClient",) +__all__ = ( + 'JobServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index 746ce91c4b..1a304de108 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -23,22 +23,20 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -46,9 +44,7 @@ from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources @@ -73,12 +69,13 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry["grpc"] = JobServiceGrpcTransport - _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport + _transport_registry['grpc'] = JobServiceGrpcTransport + _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -129,7 +126,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -164,8 +161,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: JobServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -180,194 +178,143 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path( - project: str, location: str, batch_prediction_job: str, - ) -> str: + def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, - location=location, - batch_prediction_job=batch_prediction_job, - ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str, location: str, custom_job: str,) -> str: + def custom_job_path(project: str,location: str,custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str, str]: + def parse_custom_job_path(path: str) -> Dict[str,str]: """Parse a custom_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path( - project: str, location: str, data_labeling_job: str, - ) -> str: + def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str, str]: + def parse_data_labeling_job_path(path: str) -> Dict[str,str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path( - project: str, location: str, hyperparameter_tuning_job: str, - ) -> str: + def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str, location: str, study: str, trial: str,) -> str: + def trial_path(project: str,location: str,study: str,trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( - project=project, location=location, study=study, trial=trial, - ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) @staticmethod - def parse_trial_path(path: str) -> Dict[str, str]: + def parse_trial_path(path: str) -> Dict[str,str]: """Parse a trial path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -411,9 +358,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -423,9 +368,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -437,9 +380,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -451,10 +392,8 @@ def __init__( if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -473,16 +412,15 @@ def __init__( client_info=client_info, ) - def create_custom_job( - self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job(self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -527,10 +465,8 @@ def create_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -554,24 +490,30 @@ def create_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_custom_job( - self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job(self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -609,10 +551,8 @@ def get_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -634,24 +574,30 @@ def get_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_custom_jobs( - self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs(self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -687,10 +633,8 @@ def list_custom_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -712,30 +656,39 @@ def list_custom_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_custom_job( - self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_custom_job(self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a CustomJob. Args: @@ -781,10 +734,8 @@ def delete_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -806,11 +757,18 @@ def delete_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -823,15 +781,14 @@ def delete_custom_job( # Done; return the response. return response - def cancel_custom_job( - self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job(self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -869,10 +826,8 @@ def cancel_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -894,24 +849,28 @@ def cancel_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def create_data_labeling_job( - self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + def create_data_labeling_job(self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -951,10 +910,8 @@ def create_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -978,24 +935,30 @@ def create_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_data_labeling_job( - self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job(self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -1029,10 +992,8 @@ def get_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -1054,24 +1015,30 @@ def get_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_data_labeling_jobs( - self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs(self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1106,10 +1073,8 @@ def list_data_labeling_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1131,30 +1096,39 @@ def list_data_labeling_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job( - self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_data_labeling_job(self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1201,10 +1175,8 @@ def delete_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1226,11 +1198,18 @@ def delete_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1243,15 +1222,14 @@ def delete_data_labeling_job( # Done; return the response. return response - def cancel_data_labeling_job( - self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job(self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1279,10 +1257,8 @@ def cancel_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1304,24 +1280,28 @@ def cancel_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def create_hyperparameter_tuning_job( - self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + def create_hyperparameter_tuning_job(self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1363,10 +1343,8 @@ def create_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1385,31 +1363,35 @@ def create_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.create_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_hyperparameter_tuning_job( - self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job(self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1445,10 +1427,8 @@ def get_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1465,31 +1445,35 @@ def get_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.get_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_hyperparameter_tuning_jobs( - self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs(self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1525,10 +1509,8 @@ def list_hyperparameter_tuning_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1545,37 +1527,44 @@ def list_hyperparameter_tuning_jobs( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_hyperparameter_tuning_jobs - ] + rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job( - self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_hyperparameter_tuning_job(self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1622,10 +1611,8 @@ def delete_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1642,18 +1629,23 @@ def delete_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.delete_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1666,15 +1658,14 @@ def delete_hyperparameter_tuning_job( # Done; return the response. return response - def cancel_hyperparameter_tuning_job( - self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job(self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1715,10 +1706,8 @@ def cancel_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1735,31 +1724,33 @@ def cancel_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.cancel_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def create_batch_prediction_job( - self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + def create_batch_prediction_job(self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1804,10 +1795,8 @@ def create_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1826,31 +1815,35 @@ def create_batch_prediction_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.create_batch_prediction_job - ] + rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_batch_prediction_job( - self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job(self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1888,10 +1881,8 @@ def get_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1913,24 +1904,30 @@ def get_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_batch_prediction_jobs( - self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs(self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1966,10 +1963,8 @@ def list_batch_prediction_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -1986,37 +1981,44 @@ def list_batch_prediction_jobs( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_batch_prediction_jobs - ] + rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job( - self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_batch_prediction_job(self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2064,10 +2066,8 @@ def delete_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2084,18 +2084,23 @@ def delete_batch_prediction_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.delete_batch_prediction_job - ] + rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2108,15 +2113,14 @@ def delete_batch_prediction_job( # Done; return the response. return response - def cancel_batch_prediction_job( - self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job(self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -2155,10 +2159,8 @@ def cancel_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2175,30 +2177,40 @@ def cancel_batch_prediction_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.cancel_batch_prediction_job - ] + rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("JobServiceClient",) +__all__ = ( + 'JobServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/job_service/pagers.py b/google/cloud/aiplatform_v1/services/job_service/pagers.py index 35d679b6ad..dfc5e30105 100644 --- a/google/cloud/aiplatform_v1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/job_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job @@ -50,15 +41,12 @@ class ListCustomJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListCustomJobsResponse], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -92,7 +80,7 @@ def __iter__(self) -> Iterable[custom_job.CustomJob]: yield from page.custom_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListCustomJobsAsyncPager: @@ -112,15 +100,12 @@ class ListCustomJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -158,7 +143,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataLabelingJobsPager: @@ -178,15 +163,12 @@ class ListDataLabelingJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListDataLabelingJobsResponse], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -220,7 +202,7 @@ def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: yield from page.data_labeling_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataLabelingJobsAsyncPager: @@ -240,15 +222,12 @@ class ListDataLabelingJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -286,7 +265,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsPager: @@ -306,15 +285,12 @@ class ListHyperparameterTuningJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -348,7 +324,7 @@ def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob yield from page.hyperparameter_tuning_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsAsyncPager: @@ -368,17 +344,12 @@ class ListHyperparameterTuningJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListHyperparameterTuningJobsResponse]], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -400,18 +371,14 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + async def pages(self) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__( - self, - ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + def __aiter__(self) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: async def async_generator(): async for page in self.pages: for response in page.hyperparameter_tuning_jobs: @@ -420,7 +387,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListBatchPredictionJobsPager: @@ -440,15 +407,12 @@ class ListBatchPredictionJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListBatchPredictionJobsResponse], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -482,7 +446,7 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: yield from page.batch_prediction_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListBatchPredictionJobsAsyncPager: @@ -502,15 +466,12 @@ class ListBatchPredictionJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -548,4 +509,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py index 349bfbcdea..8b5de46a7e 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] -_transport_registry["grpc"] = JobServiceGrpcTransport -_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = JobServiceGrpcTransport +_transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport __all__ = ( - "JobServiceTransport", - "JobServiceGrpcTransport", - "JobServiceGrpcAsyncIOTransport", + 'JobServiceTransport', + 'JobServiceGrpcTransport', + 'JobServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py index 42ab8e1688..f3ee6dc74a 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -21,23 +21,19 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -46,29 +42,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -84,57 +80,65 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, default_timeout=None, client_info=client_info, + self.create_custom_job, + default_timeout=None, + client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, default_timeout=None, client_info=client_info, + self.get_custom_job, + default_timeout=None, + client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, default_timeout=None, client_info=client_info, + self.list_custom_jobs, + default_timeout=None, + client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, default_timeout=None, client_info=client_info, + self.delete_custom_job, + default_timeout=None, + client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, default_timeout=None, client_info=client_info, + self.cancel_custom_job, + default_timeout=None, + client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, @@ -211,6 +215,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + } @property @@ -219,216 +224,186 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_custom_job( - self, - ) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] - ], - ]: + def create_custom_job(self) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, + typing.Awaitable[gca_custom_job.CustomJob] + ]]: raise NotImplementedError() @property - def get_custom_job( - self, - ) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], - ]: + def get_custom_job(self) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[ + custom_job.CustomJob, + typing.Awaitable[custom_job.CustomJob] + ]]: raise NotImplementedError() @property - def list_custom_jobs( - self, - ) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse], - ], - ]: + def list_custom_jobs(self) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse] + ]]: raise NotImplementedError() @property - def delete_custom_job( - self, - ) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_custom_job(self) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_custom_job( - self, - ) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_custom_job(self) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def create_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob], - ], - ]: + def create_data_labeling_job(self) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob] + ]]: raise NotImplementedError() @property - def get_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob], - ], - ]: + def get_data_labeling_job(self) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob] + ]]: raise NotImplementedError() @property - def list_data_labeling_jobs( - self, - ) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse], - ], - ]: + def list_data_labeling_jobs(self) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse] + ]]: raise NotImplementedError() @property - def delete_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_data_labeling_job(self) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_data_labeling_job(self) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def create_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], - ], - ]: + def create_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] + ]]: raise NotImplementedError() @property - def get_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], - ], - ]: + def get_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] + ]]: raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs( - self, - ) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], - ], - ]: + def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ]]: raise NotImplementedError() @property - def delete_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def create_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], - ], - ]: + def create_batch_prediction_job(self) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] + ]]: raise NotImplementedError() @property - def get_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - typing.Union[ - batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob], - ], - ]: + def get_batch_prediction_job(self) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob] + ]]: raise NotImplementedError() @property - def list_batch_prediction_jobs( - self, - ) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse], - ], - ]: + def list_batch_prediction_jobs(self) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse] + ]]: raise NotImplementedError() @property - def delete_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_batch_prediction_job(self) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_batch_prediction_job(self) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() -__all__ = ("JobServiceTransport",) +__all__ = ( + 'JobServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py index a9c90ecdaa..9a88545dd8 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -18,27 +18,23 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -58,24 +54,21 @@ class JobServiceGrpcTransport(JobServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -121,7 +114,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -129,70 +125,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -200,32 +176,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -255,12 +219,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -272,15 +237,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_custom_job( - self, - ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: + def create_custom_job(self) -> Callable[ + [job_service.CreateCustomJobRequest], + gca_custom_job.CustomJob]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -296,18 +263,18 @@ def create_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_custom_job" not in self._stubs: - self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateCustomJob", + if 'create_custom_job' not in self._stubs: + self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateCustomJob', request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs["create_custom_job"] + return self._stubs['create_custom_job'] @property - def get_custom_job( - self, - ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: + def get_custom_job(self) -> Callable[ + [job_service.GetCustomJobRequest], + custom_job.CustomJob]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -322,20 +289,18 @@ def get_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_custom_job" not in self._stubs: - self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetCustomJob", + if 'get_custom_job' not in self._stubs: + self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetCustomJob', request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs["get_custom_job"] + return self._stubs['get_custom_job'] @property - def list_custom_jobs( - self, - ) -> Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse - ]: + def list_custom_jobs(self) -> Callable[ + [job_service.ListCustomJobsRequest], + job_service.ListCustomJobsResponse]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -350,18 +315,18 @@ def list_custom_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_custom_jobs" not in self._stubs: - self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListCustomJobs", + if 'list_custom_jobs' not in self._stubs: + self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListCustomJobs', request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs["list_custom_jobs"] + return self._stubs['list_custom_jobs'] @property - def delete_custom_job( - self, - ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: + def delete_custom_job(self) -> Callable[ + [job_service.DeleteCustomJobRequest], + operations.Operation]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -376,18 +341,18 @@ def delete_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_custom_job" not in self._stubs: - self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteCustomJob", + if 'delete_custom_job' not in self._stubs: + self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteCustomJob', request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_custom_job"] + return self._stubs['delete_custom_job'] @property - def cancel_custom_job( - self, - ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: + def cancel_custom_job(self) -> Callable[ + [job_service.CancelCustomJobRequest], + empty.Empty]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -414,21 +379,18 @@ def cancel_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_custom_job" not in self._stubs: - self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelCustomJob", + if 'cancel_custom_job' not in self._stubs: + self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelCustomJob', request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_custom_job"] + return self._stubs['cancel_custom_job'] @property - def create_data_labeling_job( - self, - ) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob, - ]: + def create_data_labeling_job(self) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -443,20 +405,18 @@ def create_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_data_labeling_job" not in self._stubs: - self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob", + if 'create_data_labeling_job' not in self._stubs: + self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob', request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["create_data_labeling_job"] + return self._stubs['create_data_labeling_job'] @property - def get_data_labeling_job( - self, - ) -> Callable[ - [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob - ]: + def get_data_labeling_job(self) -> Callable[ + [job_service.GetDataLabelingJobRequest], + data_labeling_job.DataLabelingJob]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -471,21 +431,18 @@ def get_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_data_labeling_job" not in self._stubs: - self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob", + if 'get_data_labeling_job' not in self._stubs: + self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob', request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["get_data_labeling_job"] + return self._stubs['get_data_labeling_job'] @property - def list_data_labeling_jobs( - self, - ) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, - ]: + def list_data_labeling_jobs(self) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -500,18 +457,18 @@ def list_data_labeling_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_labeling_jobs" not in self._stubs: - self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs", + if 'list_data_labeling_jobs' not in self._stubs: + self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs', request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs["list_data_labeling_jobs"] + return self._stubs['list_data_labeling_jobs'] @property - def delete_data_labeling_job( - self, - ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: + def delete_data_labeling_job(self) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], + operations.Operation]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -526,18 +483,18 @@ def delete_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_data_labeling_job" not in self._stubs: - self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob", + if 'delete_data_labeling_job' not in self._stubs: + self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob', request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_data_labeling_job"] + return self._stubs['delete_data_labeling_job'] @property - def cancel_data_labeling_job( - self, - ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: + def cancel_data_labeling_job(self) -> Callable[ + [job_service.CancelDataLabelingJobRequest], + empty.Empty]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -553,21 +510,18 @@ def cancel_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_data_labeling_job" not in self._stubs: - self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob", + if 'cancel_data_labeling_job' not in self._stubs: + self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob', request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_data_labeling_job"] + return self._stubs['cancel_data_labeling_job'] @property - def create_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - ]: + def create_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -583,23 +537,18 @@ def create_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "create_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob", + if 'create_hyperparameter_tuning_job' not in self._stubs: + self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob', request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["create_hyperparameter_tuning_job"] + return self._stubs['create_hyperparameter_tuning_job'] @property - def get_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob, - ]: + def get_hyperparameter_tuning_job(self) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -614,23 +563,18 @@ def get_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "get_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob", + if 'get_hyperparameter_tuning_job' not in self._stubs: + self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob', request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["get_hyperparameter_tuning_job"] + return self._stubs['get_hyperparameter_tuning_job'] @property - def list_hyperparameter_tuning_jobs( - self, - ) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, - ]: + def list_hyperparameter_tuning_jobs(self) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -646,22 +590,18 @@ def list_hyperparameter_tuning_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_hyperparameter_tuning_jobs" not in self._stubs: - self._stubs[ - "list_hyperparameter_tuning_jobs" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs", + if 'list_hyperparameter_tuning_jobs' not in self._stubs: + self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs', request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs["list_hyperparameter_tuning_jobs"] + return self._stubs['list_hyperparameter_tuning_jobs'] @property - def delete_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation - ]: + def delete_hyperparameter_tuning_job(self) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + operations.Operation]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -677,20 +617,18 @@ def delete_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "delete_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob", + if 'delete_hyperparameter_tuning_job' not in self._stubs: + self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob', request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_hyperparameter_tuning_job"] + return self._stubs['delete_hyperparameter_tuning_job'] @property - def cancel_hyperparameter_tuning_job( - self, - ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]: + def cancel_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + empty.Empty]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -719,23 +657,18 @@ def cancel_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "cancel_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob", + if 'cancel_hyperparameter_tuning_job' not in self._stubs: + self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob', request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_hyperparameter_tuning_job"] + return self._stubs['cancel_hyperparameter_tuning_job'] @property - def create_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob, - ]: + def create_batch_prediction_job(self) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + gca_batch_prediction_job.BatchPredictionJob]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -751,21 +684,18 @@ def create_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_batch_prediction_job" not in self._stubs: - self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob", + if 'create_batch_prediction_job' not in self._stubs: + self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob', request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["create_batch_prediction_job"] + return self._stubs['create_batch_prediction_job'] @property - def get_batch_prediction_job( - self, - ) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob, - ]: + def get_batch_prediction_job(self) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -780,21 +710,18 @@ def get_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_batch_prediction_job" not in self._stubs: - self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob", + if 'get_batch_prediction_job' not in self._stubs: + self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob', request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["get_batch_prediction_job"] + return self._stubs['get_batch_prediction_job'] @property - def list_batch_prediction_jobs( - self, - ) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, - ]: + def list_batch_prediction_jobs(self) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -809,18 +736,18 @@ def list_batch_prediction_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_batch_prediction_jobs" not in self._stubs: - self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs", + if 'list_batch_prediction_jobs' not in self._stubs: + self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs', request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs["list_batch_prediction_jobs"] + return self._stubs['list_batch_prediction_jobs'] @property - def delete_batch_prediction_job( - self, - ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: + def delete_batch_prediction_job(self) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], + operations.Operation]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -836,18 +763,18 @@ def delete_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_batch_prediction_job" not in self._stubs: - self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob", + if 'delete_batch_prediction_job' not in self._stubs: + self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob', request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_batch_prediction_job"] + return self._stubs['delete_batch_prediction_job'] @property - def cancel_batch_prediction_job( - self, - ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: + def cancel_batch_prediction_job(self) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], + empty.Empty]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -873,13 +800,15 @@ def cancel_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_batch_prediction_job" not in self._stubs: - self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob", + if 'cancel_batch_prediction_job' not in self._stubs: + self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob', request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_batch_prediction_job"] + return self._stubs['cancel_batch_prediction_job'] -__all__ = ("JobServiceGrpcTransport",) +__all__ = ( + 'JobServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py index f056094c9d..2ce9fb52e0 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py @@ -18,28 +18,24 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -65,18 +61,16 @@ class JobServiceGrpcAsyncIOTransport(JobServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -102,24 +96,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -154,10 +146,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -166,7 +158,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -174,70 +169,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -245,18 +220,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -285,11 +250,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_custom_job( - self, - ) -> Callable[ - [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] - ]: + def create_custom_job(self) -> Callable[ + [job_service.CreateCustomJobRequest], + Awaitable[gca_custom_job.CustomJob]]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -305,18 +268,18 @@ def create_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_custom_job" not in self._stubs: - self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateCustomJob", + if 'create_custom_job' not in self._stubs: + self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateCustomJob', request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs["create_custom_job"] + return self._stubs['create_custom_job'] @property - def get_custom_job( - self, - ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: + def get_custom_job(self) -> Callable[ + [job_service.GetCustomJobRequest], + Awaitable[custom_job.CustomJob]]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -331,21 +294,18 @@ def get_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_custom_job" not in self._stubs: - self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetCustomJob", + if 'get_custom_job' not in self._stubs: + self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetCustomJob', request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs["get_custom_job"] + return self._stubs['get_custom_job'] @property - def list_custom_jobs( - self, - ) -> Callable[ - [job_service.ListCustomJobsRequest], - Awaitable[job_service.ListCustomJobsResponse], - ]: + def list_custom_jobs(self) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse]]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -360,20 +320,18 @@ def list_custom_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_custom_jobs" not in self._stubs: - self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListCustomJobs", + if 'list_custom_jobs' not in self._stubs: + self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListCustomJobs', request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs["list_custom_jobs"] + return self._stubs['list_custom_jobs'] @property - def delete_custom_job( - self, - ) -> Callable[ - [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] - ]: + def delete_custom_job(self) -> Callable[ + [job_service.DeleteCustomJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -388,18 +346,18 @@ def delete_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_custom_job" not in self._stubs: - self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteCustomJob", + if 'delete_custom_job' not in self._stubs: + self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteCustomJob', request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_custom_job"] + return self._stubs['delete_custom_job'] @property - def cancel_custom_job( - self, - ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: + def cancel_custom_job(self) -> Callable[ + [job_service.CancelCustomJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -426,21 +384,18 @@ def cancel_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_custom_job" not in self._stubs: - self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelCustomJob", + if 'cancel_custom_job' not in self._stubs: + self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelCustomJob', request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_custom_job"] + return self._stubs['cancel_custom_job'] @property - def create_data_labeling_job( - self, - ) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - Awaitable[gca_data_labeling_job.DataLabelingJob], - ]: + def create_data_labeling_job(self) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob]]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -455,21 +410,18 @@ def create_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_data_labeling_job" not in self._stubs: - self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob", + if 'create_data_labeling_job' not in self._stubs: + self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob', request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["create_data_labeling_job"] + return self._stubs['create_data_labeling_job'] @property - def get_data_labeling_job( - self, - ) -> Callable[ - [job_service.GetDataLabelingJobRequest], - Awaitable[data_labeling_job.DataLabelingJob], - ]: + def get_data_labeling_job(self) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob]]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -484,21 +436,18 @@ def get_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_data_labeling_job" not in self._stubs: - self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob", + if 'get_data_labeling_job' not in self._stubs: + self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob', request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["get_data_labeling_job"] + return self._stubs['get_data_labeling_job'] @property - def list_data_labeling_jobs( - self, - ) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - Awaitable[job_service.ListDataLabelingJobsResponse], - ]: + def list_data_labeling_jobs(self) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse]]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -513,20 +462,18 @@ def list_data_labeling_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_labeling_jobs" not in self._stubs: - self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs", + if 'list_data_labeling_jobs' not in self._stubs: + self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs', request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs["list_data_labeling_jobs"] + return self._stubs['list_data_labeling_jobs'] @property - def delete_data_labeling_job( - self, - ) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] - ]: + def delete_data_labeling_job(self) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -541,18 +488,18 @@ def delete_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_data_labeling_job" not in self._stubs: - self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob", + if 'delete_data_labeling_job' not in self._stubs: + self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob', request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_data_labeling_job"] + return self._stubs['delete_data_labeling_job'] @property - def cancel_data_labeling_job( - self, - ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: + def cancel_data_labeling_job(self) -> Callable[ + [job_service.CancelDataLabelingJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -568,21 +515,18 @@ def cancel_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_data_labeling_job" not in self._stubs: - self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob", + if 'cancel_data_labeling_job' not in self._stubs: + self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob', request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_data_labeling_job"] + return self._stubs['cancel_data_labeling_job'] @property - def create_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], - ]: + def create_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob]]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -598,23 +542,18 @@ def create_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "create_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob", + if 'create_hyperparameter_tuning_job' not in self._stubs: + self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob', request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["create_hyperparameter_tuning_job"] + return self._stubs['create_hyperparameter_tuning_job'] @property - def get_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], - ]: + def get_hyperparameter_tuning_job(self) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob]]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -629,23 +568,18 @@ def get_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "get_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob", + if 'get_hyperparameter_tuning_job' not in self._stubs: + self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob', request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["get_hyperparameter_tuning_job"] + return self._stubs['get_hyperparameter_tuning_job'] @property - def list_hyperparameter_tuning_jobs( - self, - ) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - Awaitable[job_service.ListHyperparameterTuningJobsResponse], - ]: + def list_hyperparameter_tuning_jobs(self) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse]]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -661,23 +595,18 @@ def list_hyperparameter_tuning_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_hyperparameter_tuning_jobs" not in self._stubs: - self._stubs[ - "list_hyperparameter_tuning_jobs" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs", + if 'list_hyperparameter_tuning_jobs' not in self._stubs: + self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs', request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs["list_hyperparameter_tuning_jobs"] + return self._stubs['list_hyperparameter_tuning_jobs'] @property - def delete_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - Awaitable[operations.Operation], - ]: + def delete_hyperparameter_tuning_job(self) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -693,22 +622,18 @@ def delete_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "delete_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob", + if 'delete_hyperparameter_tuning_job' not in self._stubs: + self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob', request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_hyperparameter_tuning_job"] + return self._stubs['delete_hyperparameter_tuning_job'] @property - def cancel_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] - ]: + def cancel_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -737,23 +662,18 @@ def cancel_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "cancel_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob", + if 'cancel_hyperparameter_tuning_job' not in self._stubs: + self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob', request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_hyperparameter_tuning_job"] + return self._stubs['cancel_hyperparameter_tuning_job'] @property - def create_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - Awaitable[gca_batch_prediction_job.BatchPredictionJob], - ]: + def create_batch_prediction_job(self) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob]]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -769,21 +689,18 @@ def create_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_batch_prediction_job" not in self._stubs: - self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob", + if 'create_batch_prediction_job' not in self._stubs: + self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob', request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["create_batch_prediction_job"] + return self._stubs['create_batch_prediction_job'] @property - def get_batch_prediction_job( - self, - ) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - Awaitable[batch_prediction_job.BatchPredictionJob], - ]: + def get_batch_prediction_job(self) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob]]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -798,21 +715,18 @@ def get_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_batch_prediction_job" not in self._stubs: - self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob", + if 'get_batch_prediction_job' not in self._stubs: + self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob', request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["get_batch_prediction_job"] + return self._stubs['get_batch_prediction_job'] @property - def list_batch_prediction_jobs( - self, - ) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - Awaitable[job_service.ListBatchPredictionJobsResponse], - ]: + def list_batch_prediction_jobs(self) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse]]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -827,20 +741,18 @@ def list_batch_prediction_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_batch_prediction_jobs" not in self._stubs: - self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs", + if 'list_batch_prediction_jobs' not in self._stubs: + self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs', request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs["list_batch_prediction_jobs"] + return self._stubs['list_batch_prediction_jobs'] @property - def delete_batch_prediction_job( - self, - ) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] - ]: + def delete_batch_prediction_job(self) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -856,20 +768,18 @@ def delete_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_batch_prediction_job" not in self._stubs: - self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob", + if 'delete_batch_prediction_job' not in self._stubs: + self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob', request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_batch_prediction_job"] + return self._stubs['delete_batch_prediction_job'] @property - def cancel_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] - ]: + def cancel_batch_prediction_job(self) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -895,13 +805,15 @@ def cancel_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_batch_prediction_job" not in self._stubs: - self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob", + if 'cancel_batch_prediction_job' not in self._stubs: + self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob', request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_batch_prediction_job"] + return self._stubs['cancel_batch_prediction_job'] -__all__ = ("JobServiceGrpcAsyncIOTransport",) +__all__ = ( + 'JobServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1/services/migration_service/__init__.py index 1d6216d1f7..c533a12b45 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/migration_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import MigrationServiceAsyncClient __all__ = ( - "MigrationServiceClient", - "MigrationServiceAsyncClient", + 'MigrationServiceClient', + 'MigrationServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py index e7f45eeaf5..d48eb4ae0b 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -51,9 +51,7 @@ class MigrationServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path) - parse_annotated_dataset_path = staticmethod( - MigrationServiceClient.parse_annotated_dataset_path - ) + parse_annotated_dataset_path = staticmethod(MigrationServiceClient.parse_annotated_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) @@ -67,34 +65,20 @@ class MigrationServiceAsyncClient: version_path = staticmethod(MigrationServiceClient.version_path) parse_version_path = staticmethod(MigrationServiceClient.parse_version_path) - common_billing_account_path = staticmethod( - MigrationServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - MigrationServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(MigrationServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(MigrationServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(MigrationServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - MigrationServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(MigrationServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - MigrationServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - MigrationServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(MigrationServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(MigrationServiceClient.parse_common_organization_path) common_project_path = staticmethod(MigrationServiceClient.common_project_path) - parse_common_project_path = staticmethod( - MigrationServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(MigrationServiceClient.parse_common_project_path) common_location_path = staticmethod(MigrationServiceClient.common_location_path) - parse_common_location_path = staticmethod( - MigrationServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(MigrationServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -137,18 +121,14 @@ def transport(self) -> MigrationServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient) - ) + get_transport_class = functools.partial(type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, MigrationServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, MigrationServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -187,17 +167,17 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def search_migratable_resources( - self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesAsyncPager: + async def search_migratable_resources(self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -238,10 +218,8 @@ async def search_migratable_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = migration_service.SearchMigratableResourcesRequest(request) @@ -262,33 +240,40 @@ async def search_migratable_resources( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchMigratableResourcesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def batch_migrate_resources( - self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[ - migration_service.MigrateResourceRequest - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_migrate_resources(self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -337,10 +322,8 @@ async def batch_migrate_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = migration_service.BatchMigrateResourcesRequest(request) @@ -364,11 +347,18 @@ async def batch_migrate_resources( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -382,14 +372,21 @@ async def batch_migrate_resources( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("MigrationServiceAsyncClient",) +__all__ = ( + 'MigrationServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 0a23f262c2..94758701d8 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,14 +50,13 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry['grpc'] = MigrationServiceGrpcTransport + _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry["grpc"] = MigrationServiceGrpcTransport - _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -111,7 +110,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -146,8 +145,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MigrationServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -162,183 +162,143 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path( - project: str, dataset: str, annotated_dataset: str, - ) -> str: + def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( - project=project, dataset=dataset, annotated_dataset=annotated_dataset, - ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str, str]: + def parse_annotated_dataset_path(path: str) -> Dict[str,str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def version_path(project: str, model: str, version: str,) -> str: + def version_path(project: str,model: str,version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format( - project=project, model=model, version=version, - ) + return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) @staticmethod - def parse_version_path(path: str) -> Dict[str, str]: + def parse_version_path(path: str) -> Dict[str,str]: """Parse a version path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -382,9 +342,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -394,9 +352,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -408,9 +364,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -422,10 +376,8 @@ def __init__( if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -444,15 +396,14 @@ def __init__( client_info=client_info, ) - def search_migratable_resources( - self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources(self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -493,10 +444,8 @@ def search_migratable_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -513,40 +462,45 @@ def search_migratable_resources( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.search_migratable_resources - ] + rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources( - self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[ - migration_service.MigrateResourceRequest - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources(self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -595,10 +549,8 @@ def batch_migrate_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -622,11 +574,18 @@ def batch_migrate_resources( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation.from_gapic( @@ -640,14 +599,21 @@ def batch_migrate_resources( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("MigrationServiceClient",) +__all__ = ( + 'MigrationServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1/services/migration_service/pagers.py index 02a46451df..08654cbf6e 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/migration_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import migratable_resource from google.cloud.aiplatform_v1.types import migration_service @@ -47,15 +38,12 @@ class SearchMigratableResourcesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., migration_service.SearchMigratableResourcesResponse], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: yield from page.migratable_resources def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class SearchMigratableResourcesAsyncPager: @@ -109,17 +97,12 @@ class SearchMigratableResourcesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[migration_service.SearchMigratableResourcesResponse] - ], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[migration_service.SearchMigratableResourcesResponse]], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -141,9 +124,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + async def pages(self) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -159,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py index 38c72756f6..9fb765fdcc 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] -_transport_registry["grpc"] = MigrationServiceGrpcTransport -_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = MigrationServiceGrpcTransport +_transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport __all__ = ( - "MigrationServiceTransport", - "MigrationServiceGrpcTransport", - "MigrationServiceGrpcAsyncIOTransport", + 'MigrationServiceTransport', + 'MigrationServiceGrpcTransport', + 'MigrationServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py index da4cabae63..4f31e9b243 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -33,29 +33,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class MigrationServiceTransport(abc.ABC): """Abstract transport class for MigrationService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -71,40 +71,38 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -118,6 +116,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + } @property @@ -126,25 +125,24 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def search_migratable_resources( - self, - ) -> typing.Callable[ - [migration_service.SearchMigratableResourcesRequest], - typing.Union[ - migration_service.SearchMigratableResourcesResponse, - typing.Awaitable[migration_service.SearchMigratableResourcesResponse], - ], - ]: + def search_migratable_resources(self) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse] + ]]: raise NotImplementedError() @property - def batch_migrate_resources( - self, - ) -> typing.Callable[ - [migration_service.BatchMigrateResourcesRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def batch_migrate_resources(self) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() -__all__ = ("MigrationServiceTransport",) +__all__ = ( + 'MigrationServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py index f11d72386d..49659f9b31 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -47,24 +47,21 @@ class MigrationServiceGrpcTransport(MigrationServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -110,7 +107,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -118,70 +118,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -189,32 +169,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -244,12 +212,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -261,18 +230,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def search_migratable_resources( - self, - ) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - migration_service.SearchMigratableResourcesResponse, - ]: + def search_migratable_resources(self) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -290,20 +258,18 @@ def search_migratable_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "search_migratable_resources" not in self._stubs: - self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources", + if 'search_migratable_resources' not in self._stubs: + self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources', request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs["search_migratable_resources"] + return self._stubs['search_migratable_resources'] @property - def batch_migrate_resources( - self, - ) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], operations.Operation - ]: + def batch_migrate_resources(self) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + operations.Operation]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -320,13 +286,15 @@ def batch_migrate_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "batch_migrate_resources" not in self._stubs: - self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources", + if 'batch_migrate_resources' not in self._stubs: + self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources', request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["batch_migrate_resources"] + return self._stubs['batch_migrate_resources'] -__all__ = ("MigrationServiceGrpcTransport",) +__all__ = ( + 'MigrationServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py index dbdddf31e5..600f8893fe 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import migration_service @@ -54,18 +54,16 @@ class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -91,24 +89,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -143,10 +139,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -155,7 +151,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -163,70 +162,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -234,18 +213,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -274,12 +243,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def search_migratable_resources( - self, - ) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - Awaitable[migration_service.SearchMigratableResourcesResponse], - ]: + def search_migratable_resources(self) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse]]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -297,21 +263,18 @@ def search_migratable_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "search_migratable_resources" not in self._stubs: - self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources", + if 'search_migratable_resources' not in self._stubs: + self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources', request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs["search_migratable_resources"] + return self._stubs['search_migratable_resources'] @property - def batch_migrate_resources( - self, - ) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - Awaitable[operations.Operation], - ]: + def batch_migrate_resources(self) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -328,13 +291,15 @@ def batch_migrate_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "batch_migrate_resources" not in self._stubs: - self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources", + if 'batch_migrate_resources' not in self._stubs: + self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources', request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["batch_migrate_resources"] + return self._stubs['batch_migrate_resources'] -__all__ = ("MigrationServiceGrpcAsyncIOTransport",) +__all__ = ( + 'MigrationServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/model_service/__init__.py b/google/cloud/aiplatform_v1/services/model_service/__init__.py index b39295ebfe..3ee8fc6e9e 100644 --- a/google/cloud/aiplatform_v1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/model_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import ModelServiceAsyncClient __all__ = ( - "ModelServiceClient", - "ModelServiceAsyncClient", + 'ModelServiceClient', + 'ModelServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index 687c22455a..a65c5df60f 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -62,44 +62,26 @@ class ModelServiceAsyncClient: model_path = staticmethod(ModelServiceClient.model_path) parse_model_path = staticmethod(ModelServiceClient.parse_model_path) model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) - parse_model_evaluation_path = staticmethod( - ModelServiceClient.parse_model_evaluation_path - ) - model_evaluation_slice_path = staticmethod( - ModelServiceClient.model_evaluation_slice_path - ) - parse_model_evaluation_slice_path = staticmethod( - ModelServiceClient.parse_model_evaluation_slice_path - ) + parse_model_evaluation_path = staticmethod(ModelServiceClient.parse_model_evaluation_path) + model_evaluation_slice_path = staticmethod(ModelServiceClient.model_evaluation_slice_path) + parse_model_evaluation_slice_path = staticmethod(ModelServiceClient.parse_model_evaluation_slice_path) training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod( - ModelServiceClient.parse_training_pipeline_path - ) + parse_training_pipeline_path = staticmethod(ModelServiceClient.parse_training_pipeline_path) - common_billing_account_path = staticmethod( - ModelServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - ModelServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(ModelServiceClient.common_folder_path) parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod( - ModelServiceClient.parse_common_organization_path - ) + parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod( - ModelServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod( - ModelServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -142,18 +124,14 @@ def transport(self) -> ModelServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(ModelServiceClient).get_transport_class, type(ModelServiceClient) - ) + get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -192,18 +170,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def upload_model( - self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def upload_model(self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Uploads a Model artifact into AI Platform. Args: @@ -246,10 +224,8 @@ async def upload_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.UploadModelRequest(request) @@ -272,11 +248,18 @@ async def upload_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -289,15 +272,14 @@ async def upload_model( # Done; return the response. return response - async def get_model( - self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + async def get_model(self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -327,10 +309,8 @@ async def get_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.GetModelRequest(request) @@ -351,24 +331,30 @@ async def get_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_models( - self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: + async def list_models(self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: r"""Lists Models in a Location. Args: @@ -404,10 +390,8 @@ async def list_models( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ListModelsRequest(request) @@ -428,31 +412,40 @@ async def list_models( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def update_model( - self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + async def update_model(self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -490,10 +483,8 @@ async def update_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.UpdateModelRequest(request) @@ -516,26 +507,30 @@ async def update_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("model.name", request.model.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('model.name', request.model.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def delete_model( - self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_model(self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -583,10 +578,8 @@ async def delete_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.DeleteModelRequest(request) @@ -607,11 +600,18 @@ async def delete_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -624,16 +624,15 @@ async def delete_model( # Done; return the response. return response - async def export_model( - self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_model(self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -681,10 +680,8 @@ async def export_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ExportModelRequest(request) @@ -707,11 +704,18 @@ async def export_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -724,15 +728,14 @@ async def export_model( # Done; return the response. return response - async def get_model_evaluation( - self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + async def get_model_evaluation(self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -768,10 +771,8 @@ async def get_model_evaluation( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.GetModelEvaluationRequest(request) @@ -792,24 +793,30 @@ async def get_model_evaluation( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_model_evaluations( - self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsAsyncPager: + async def list_model_evaluations(self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: r"""Lists ModelEvaluations in a Model. Args: @@ -845,10 +852,8 @@ async def list_model_evaluations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ListModelEvaluationsRequest(request) @@ -869,30 +874,39 @@ async def list_model_evaluations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def get_model_evaluation_slice( - self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + async def get_model_evaluation_slice(self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -928,10 +942,8 @@ async def get_model_evaluation_slice( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.GetModelEvaluationSliceRequest(request) @@ -952,24 +964,30 @@ async def get_model_evaluation_slice( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_model_evaluation_slices( - self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesAsyncPager: + async def list_model_evaluation_slices(self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1006,10 +1024,8 @@ async def list_model_evaluation_slices( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ListModelEvaluationSlicesRequest(request) @@ -1030,30 +1046,47 @@ async def list_model_evaluation_slices( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationSlicesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("ModelServiceAsyncClient",) +__all__ = ( + 'ModelServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index fa75f3c22b..9d5ebc8008 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,12 +60,13 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry["grpc"] = ModelServiceGrpcTransport - _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport + _transport_registry['grpc'] = ModelServiceGrpcTransport + _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -151,8 +152,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -167,162 +169,121 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path( - project: str, location: str, model: str, evaluation: str, - ) -> str: + def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( - project=project, location=location, model=model, evaluation=evaluation, - ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str, str]: + def parse_model_evaluation_path(path: str) -> Dict[str,str]: """Parse a model_evaluation path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path( - project: str, location: str, model: str, evaluation: str, slice: str, - ) -> str: + def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( - project=project, - location=location, - model=model, - evaluation=evaluation, - slice=slice, - ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path( - project: str, location: str, training_pipeline: str, - ) -> str: + def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str, str]: + def parse_training_pipeline_path(path: str) -> Dict[str,str]: """Parse a training_pipeline path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -366,9 +327,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -378,9 +337,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -392,9 +349,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -406,10 +361,8 @@ def __init__( if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -428,16 +381,15 @@ def __init__( client_info=client_info, ) - def upload_model( - self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def upload_model(self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -480,10 +432,8 @@ def upload_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -507,11 +457,18 @@ def upload_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -524,15 +481,14 @@ def upload_model( # Done; return the response. return response - def get_model( - self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model(self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -562,10 +518,8 @@ def get_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -587,24 +541,30 @@ def get_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_models( - self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models(self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -640,10 +600,8 @@ def list_models( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -665,31 +623,40 @@ def list_models( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def update_model( - self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model(self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -727,10 +694,8 @@ def update_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -754,26 +719,30 @@ def update_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("model.name", request.model.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('model.name', request.model.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_model( - self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model(self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -821,10 +790,8 @@ def delete_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -846,11 +813,18 @@ def delete_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -863,16 +837,15 @@ def delete_model( # Done; return the response. return response - def export_model( - self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_model(self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -920,10 +893,8 @@ def export_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -947,11 +918,18 @@ def export_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -964,15 +942,14 @@ def export_model( # Done; return the response. return response - def get_model_evaluation( - self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation(self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -1008,10 +985,8 @@ def get_model_evaluation( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -1033,24 +1008,30 @@ def get_model_evaluation( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_model_evaluations( - self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations(self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1086,10 +1067,8 @@ def list_model_evaluations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1111,30 +1090,39 @@ def list_model_evaluations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice( - self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice(self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1170,10 +1158,8 @@ def get_model_evaluation_slice( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1190,31 +1176,35 @@ def get_model_evaluation_slice( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.get_model_evaluation_slice - ] + rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_model_evaluation_slices( - self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices(self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1251,10 +1241,8 @@ def list_model_evaluation_slices( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1271,37 +1259,52 @@ def list_model_evaluation_slices( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_model_evaluation_slices - ] + rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("ModelServiceClient",) +__all__ = ( + 'ModelServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/model_service/pagers.py b/google/cloud/aiplatform_v1/services/model_service/pagers.py index d01f0057c1..cf94a17fea 100644 --- a/google/cloud/aiplatform_v1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/model_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import model from google.cloud.aiplatform_v1.types import model_evaluation @@ -49,15 +40,12 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -91,7 +79,7 @@ def __iter__(self) -> Iterable[model.Model]: yield from page.models def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelsAsyncPager: @@ -111,15 +99,12 @@ class ListModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -157,7 +142,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationsPager: @@ -177,15 +162,12 @@ class ListModelEvaluationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., model_service.ListModelEvaluationsResponse], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -219,7 +201,7 @@ def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: yield from page.model_evaluations def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationsAsyncPager: @@ -239,15 +221,12 @@ class ListModelEvaluationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -285,7 +264,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesPager: @@ -305,15 +284,12 @@ class ListModelEvaluationSlicesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., model_service.ListModelEvaluationSlicesResponse], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -347,7 +323,7 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: yield from page.model_evaluation_slices def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesAsyncPager: @@ -367,17 +343,12 @@ class ListModelEvaluationSlicesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] - ], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationSlicesResponse]], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -399,9 +370,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -417,4 +386,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py index 5d1cb51abc..833862a1d6 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry["grpc"] = ModelServiceGrpcTransport -_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = ModelServiceGrpcTransport +_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport __all__ = ( - "ModelServiceTransport", - "ModelServiceGrpcTransport", - "ModelServiceGrpcAsyncIOTransport", + 'ModelServiceTransport', + 'ModelServiceGrpcTransport', + 'ModelServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py index d937f09a61..80c34f3e4a 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -37,29 +37,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -75,60 +75,70 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, default_timeout=None, client_info=client_info, + self.upload_model, + default_timeout=None, + client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( - self.get_model, default_timeout=None, client_info=client_info, + self.get_model, + default_timeout=None, + client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( - self.list_models, default_timeout=None, client_info=client_info, + self.list_models, + default_timeout=None, + client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( - self.update_model, default_timeout=None, client_info=client_info, + self.update_model, + default_timeout=None, + client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, default_timeout=None, client_info=client_info, + self.delete_model, + default_timeout=None, + client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( - self.export_model, default_timeout=None, client_info=client_info, + self.export_model, + default_timeout=None, + client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( self.get_model_evaluation, @@ -150,6 +160,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + } @property @@ -158,109 +169,96 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def upload_model( - self, - ) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def upload_model(self) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_model( - self, - ) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[model.Model, typing.Awaitable[model.Model]], - ]: + def get_model(self) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[ + model.Model, + typing.Awaitable[model.Model] + ]]: raise NotImplementedError() @property - def list_models( - self, - ) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse], - ], - ]: + def list_models(self) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse] + ]]: raise NotImplementedError() @property - def update_model( - self, - ) -> typing.Callable[ - [model_service.UpdateModelRequest], - typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], - ]: + def update_model(self) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[ + gca_model.Model, + typing.Awaitable[gca_model.Model] + ]]: raise NotImplementedError() @property - def delete_model( - self, - ) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_model(self) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def export_model( - self, - ) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def export_model(self) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_model_evaluation( - self, - ) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation], - ], - ]: + def get_model_evaluation(self) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation] + ]]: raise NotImplementedError() @property - def list_model_evaluations( - self, - ) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse], - ], - ]: + def list_model_evaluations(self) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse] + ]]: raise NotImplementedError() @property - def get_model_evaluation_slice( - self, - ) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], - ], - ]: + def get_model_evaluation_slice(self) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] + ]]: raise NotImplementedError() @property - def list_model_evaluation_slices( - self, - ) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], - ], - ]: + def list_model_evaluation_slices(self) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] + ]]: raise NotImplementedError() -__all__ = ("ModelServiceTransport",) +__all__ = ( + 'ModelServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py index b6f2efb427..d05154e2fb 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -49,24 +49,21 @@ class ModelServiceGrpcTransport(ModelServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -112,7 +109,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -120,70 +120,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -191,32 +171,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -246,12 +214,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -263,15 +232,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def upload_model( - self, - ) -> Callable[[model_service.UploadModelRequest], operations.Operation]: + def upload_model(self) -> Callable[ + [model_service.UploadModelRequest], + operations.Operation]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -286,16 +257,18 @@ def upload_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "upload_model" not in self._stubs: - self._stubs["upload_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/UploadModel", + if 'upload_model' not in self._stubs: + self._stubs['upload_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/UploadModel', request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["upload_model"] + return self._stubs['upload_model'] @property - def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + model.Model]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -310,18 +283,18 @@ def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model" not in self._stubs: - self._stubs["get_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/GetModel", + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/GetModel', request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs["get_model"] + return self._stubs['get_model'] @property - def list_models( - self, - ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -336,18 +309,18 @@ def list_models( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_models" not in self._stubs: - self._stubs["list_models"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ListModels", + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ListModels', request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs["list_models"] + return self._stubs['list_models'] @property - def update_model( - self, - ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: + def update_model(self) -> Callable[ + [model_service.UpdateModelRequest], + gca_model.Model]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -362,18 +335,18 @@ def update_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_model" not in self._stubs: - self._stubs["update_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/UpdateModel", + if 'update_model' not in self._stubs: + self._stubs['update_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/UpdateModel', request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs["update_model"] + return self._stubs['update_model'] @property - def delete_model( - self, - ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: + def delete_model(self) -> Callable[ + [model_service.DeleteModelRequest], + operations.Operation]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -390,18 +363,18 @@ def delete_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_model" not in self._stubs: - self._stubs["delete_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/DeleteModel", + if 'delete_model' not in self._stubs: + self._stubs['delete_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/DeleteModel', request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_model"] + return self._stubs['delete_model'] @property - def export_model( - self, - ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: + def export_model(self) -> Callable[ + [model_service.ExportModelRequest], + operations.Operation]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -419,20 +392,18 @@ def export_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_model" not in self._stubs: - self._stubs["export_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ExportModel", + if 'export_model' not in self._stubs: + self._stubs['export_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ExportModel', request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_model"] + return self._stubs['export_model'] @property - def get_model_evaluation( - self, - ) -> Callable[ - [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation - ]: + def get_model_evaluation(self) -> Callable[ + [model_service.GetModelEvaluationRequest], + model_evaluation.ModelEvaluation]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -447,21 +418,18 @@ def get_model_evaluation( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation" not in self._stubs: - self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation", + if 'get_model_evaluation' not in self._stubs: + self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation', request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs["get_model_evaluation"] + return self._stubs['get_model_evaluation'] @property - def list_model_evaluations( - self, - ) -> Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, - ]: + def list_model_evaluations(self) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -476,21 +444,18 @@ def list_model_evaluations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluations" not in self._stubs: - self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations", + if 'list_model_evaluations' not in self._stubs: + self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations', request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs["list_model_evaluations"] + return self._stubs['list_model_evaluations'] @property - def get_model_evaluation_slice( - self, - ) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice, - ]: + def get_model_evaluation_slice(self) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -505,21 +470,18 @@ def get_model_evaluation_slice( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation_slice" not in self._stubs: - self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice", + if 'get_model_evaluation_slice' not in self._stubs: + self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice', request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs["get_model_evaluation_slice"] + return self._stubs['get_model_evaluation_slice'] @property - def list_model_evaluation_slices( - self, - ) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, - ]: + def list_model_evaluation_slices(self) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -534,13 +496,15 @@ def list_model_evaluation_slices( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluation_slices" not in self._stubs: - self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices", + if 'list_model_evaluation_slices' not in self._stubs: + self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices', request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs["list_model_evaluation_slices"] + return self._stubs['list_model_evaluation_slices'] -__all__ = ("ModelServiceGrpcTransport",) +__all__ = ( + 'ModelServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py index 2aeffea93f..1e24fe3d5c 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import model @@ -56,18 +56,16 @@ class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -93,24 +91,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -145,10 +141,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -157,7 +153,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -165,70 +164,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -236,18 +215,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -276,9 +245,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def upload_model( - self, - ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: + def upload_model(self) -> Callable[ + [model_service.UploadModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -293,18 +262,18 @@ def upload_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "upload_model" not in self._stubs: - self._stubs["upload_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/UploadModel", + if 'upload_model' not in self._stubs: + self._stubs['upload_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/UploadModel', request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["upload_model"] + return self._stubs['upload_model'] @property - def get_model( - self, - ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + Awaitable[model.Model]]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -319,20 +288,18 @@ def get_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model" not in self._stubs: - self._stubs["get_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/GetModel", + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/GetModel', request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs["get_model"] + return self._stubs['get_model'] @property - def list_models( - self, - ) -> Callable[ - [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] - ]: + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + Awaitable[model_service.ListModelsResponse]]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -347,18 +314,18 @@ def list_models( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_models" not in self._stubs: - self._stubs["list_models"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ListModels", + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ListModels', request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs["list_models"] + return self._stubs['list_models'] @property - def update_model( - self, - ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: + def update_model(self) -> Callable[ + [model_service.UpdateModelRequest], + Awaitable[gca_model.Model]]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -373,18 +340,18 @@ def update_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_model" not in self._stubs: - self._stubs["update_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/UpdateModel", + if 'update_model' not in self._stubs: + self._stubs['update_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/UpdateModel', request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs["update_model"] + return self._stubs['update_model'] @property - def delete_model( - self, - ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: + def delete_model(self) -> Callable[ + [model_service.DeleteModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -401,18 +368,18 @@ def delete_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_model" not in self._stubs: - self._stubs["delete_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/DeleteModel", + if 'delete_model' not in self._stubs: + self._stubs['delete_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/DeleteModel', request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_model"] + return self._stubs['delete_model'] @property - def export_model( - self, - ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: + def export_model(self) -> Callable[ + [model_service.ExportModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -430,21 +397,18 @@ def export_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_model" not in self._stubs: - self._stubs["export_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ExportModel", + if 'export_model' not in self._stubs: + self._stubs['export_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ExportModel', request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_model"] + return self._stubs['export_model'] @property - def get_model_evaluation( - self, - ) -> Callable[ - [model_service.GetModelEvaluationRequest], - Awaitable[model_evaluation.ModelEvaluation], - ]: + def get_model_evaluation(self) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation]]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -459,21 +423,18 @@ def get_model_evaluation( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation" not in self._stubs: - self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation", + if 'get_model_evaluation' not in self._stubs: + self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation', request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs["get_model_evaluation"] + return self._stubs['get_model_evaluation'] @property - def list_model_evaluations( - self, - ) -> Callable[ - [model_service.ListModelEvaluationsRequest], - Awaitable[model_service.ListModelEvaluationsResponse], - ]: + def list_model_evaluations(self) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse]]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -488,21 +449,18 @@ def list_model_evaluations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluations" not in self._stubs: - self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations", + if 'list_model_evaluations' not in self._stubs: + self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations', request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs["list_model_evaluations"] + return self._stubs['list_model_evaluations'] @property - def get_model_evaluation_slice( - self, - ) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - Awaitable[model_evaluation_slice.ModelEvaluationSlice], - ]: + def get_model_evaluation_slice(self) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice]]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -517,21 +475,18 @@ def get_model_evaluation_slice( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation_slice" not in self._stubs: - self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice", + if 'get_model_evaluation_slice' not in self._stubs: + self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice', request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs["get_model_evaluation_slice"] + return self._stubs['get_model_evaluation_slice'] @property - def list_model_evaluation_slices( - self, - ) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - Awaitable[model_service.ListModelEvaluationSlicesResponse], - ]: + def list_model_evaluation_slices(self) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse]]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -546,13 +501,15 @@ def list_model_evaluation_slices( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluation_slices" not in self._stubs: - self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices", + if 'list_model_evaluation_slices' not in self._stubs: + self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices', request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs["list_model_evaluation_slices"] + return self._stubs['list_model_evaluation_slices'] -__all__ = ("ModelServiceGrpcAsyncIOTransport",) +__all__ = ( + 'ModelServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py index 7f02b47358..f7f4d9b9ac 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PipelineServiceAsyncClient __all__ = ( - "PipelineServiceClient", - "PipelineServiceAsyncClient", + 'PipelineServiceClient', + 'PipelineServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index fc7337a7a3..276c0980f5 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,38 +61,22 @@ class PipelineServiceAsyncClient: model_path = staticmethod(PipelineServiceClient.model_path) parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod( - PipelineServiceClient.parse_training_pipeline_path - ) + parse_training_pipeline_path = staticmethod(PipelineServiceClient.parse_training_pipeline_path) - common_billing_account_path = staticmethod( - PipelineServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - PipelineServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(PipelineServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PipelineServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - PipelineServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(PipelineServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - PipelineServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - PipelineServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(PipelineServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PipelineServiceClient.parse_common_organization_path) common_project_path = staticmethod(PipelineServiceClient.common_project_path) - parse_common_project_path = staticmethod( - PipelineServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(PipelineServiceClient.parse_common_project_path) common_location_path = staticmethod(PipelineServiceClient.common_location_path) - parse_common_location_path = staticmethod( - PipelineServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(PipelineServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -135,18 +119,14 @@ def transport(self) -> PipelineServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) - ) + get_transport_class = functools.partial(type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -185,18 +165,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_training_pipeline( - self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + async def create_training_pipeline(self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -241,10 +221,8 @@ async def create_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.CreateTrainingPipelineRequest(request) @@ -267,24 +245,30 @@ async def create_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_training_pipeline( - self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + async def get_training_pipeline(self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -322,10 +306,8 @@ async def get_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.GetTrainingPipelineRequest(request) @@ -346,24 +328,30 @@ async def get_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_training_pipelines( - self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesAsyncPager: + async def list_training_pipelines(self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: r"""Lists TrainingPipelines in a Location. Args: @@ -399,10 +387,8 @@ async def list_training_pipelines( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.ListTrainingPipelinesRequest(request) @@ -423,30 +409,39 @@ async def list_training_pipelines( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrainingPipelinesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_training_pipeline( - self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_training_pipeline(self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a TrainingPipeline. Args: @@ -493,10 +488,8 @@ async def delete_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.DeleteTrainingPipelineRequest(request) @@ -517,11 +510,18 @@ async def delete_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -534,15 +534,14 @@ async def delete_training_pipeline( # Done; return the response. return response - async def cancel_training_pipeline( - self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_training_pipeline(self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -582,10 +581,8 @@ async def cancel_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.CancelTrainingPipelineRequest(request) @@ -606,23 +603,35 @@ async def cancel_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PipelineServiceAsyncClient",) +__all__ = ( + 'PipelineServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index 39f37eb72e..fe36174dda 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -59,14 +59,13 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry['grpc'] = PipelineServiceGrpcTransport + _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry["grpc"] = PipelineServiceGrpcTransport - _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +116,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,8 +151,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PipelineServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,122 +168,99 @@ def transport(self) -> PipelineServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path( - project: str, location: str, training_pipeline: str, - ) -> str: + def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str, str]: + def parse_training_pipeline_path(path: str) -> Dict[str,str]: """Parse a training_pipeline path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -327,9 +304,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -339,9 +314,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -353,9 +326,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -367,10 +338,8 @@ def __init__( if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -389,16 +358,15 @@ def __init__( client_info=client_info, ) - def create_training_pipeline( - self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline(self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -443,10 +411,8 @@ def create_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -470,24 +436,30 @@ def create_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_training_pipeline( - self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline(self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -525,10 +497,8 @@ def get_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -550,24 +520,30 @@ def get_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_training_pipelines( - self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines(self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -603,10 +579,8 @@ def list_training_pipelines( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -628,30 +602,39 @@ def list_training_pipelines( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline( - self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_training_pipeline(self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -698,10 +681,8 @@ def delete_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -723,11 +704,18 @@ def delete_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -740,15 +728,14 @@ def delete_training_pipeline( # Done; return the response. return response - def cancel_training_pipeline( - self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline(self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -788,10 +775,8 @@ def cancel_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -813,23 +798,35 @@ def cancel_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PipelineServiceClient",) +__all__ = ( + 'PipelineServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py index 987c37dba2..ec626400ec 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import pipeline_service from google.cloud.aiplatform_v1.types import training_pipeline @@ -47,15 +38,12 @@ class ListTrainingPipelinesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: yield from page.training_pipelines def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListTrainingPipelinesAsyncPager: @@ -109,17 +97,12 @@ class ListTrainingPipelinesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] - ], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[pipeline_service.ListTrainingPipelinesResponse]], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -141,9 +124,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + async def pages(self) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -159,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py index 9d4610087a..f289718f83 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] -_transport_registry["grpc"] = PipelineServiceGrpcTransport -_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = PipelineServiceGrpcTransport +_transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport __all__ = ( - "PipelineServiceTransport", - "PipelineServiceGrpcTransport", - "PipelineServiceGrpcAsyncIOTransport", + 'PipelineServiceTransport', + 'PipelineServiceGrpcTransport', + 'PipelineServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py index e4bc8e66a8..3a0cfa5a08 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -74,40 +74,38 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -136,6 +134,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + } @property @@ -144,58 +143,51 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - typing.Union[ - gca_training_pipeline.TrainingPipeline, - typing.Awaitable[gca_training_pipeline.TrainingPipeline], - ], - ]: + def create_training_pipeline(self) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline] + ]]: raise NotImplementedError() @property - def get_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.GetTrainingPipelineRequest], - typing.Union[ - training_pipeline.TrainingPipeline, - typing.Awaitable[training_pipeline.TrainingPipeline], - ], - ]: + def get_training_pipeline(self) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline] + ]]: raise NotImplementedError() @property - def list_training_pipelines( - self, - ) -> typing.Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - typing.Union[ - pipeline_service.ListTrainingPipelinesResponse, - typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], - ], - ]: + def list_training_pipelines(self) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ]]: raise NotImplementedError() @property - def delete_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_training_pipeline(self) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_training_pipeline(self) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() -__all__ = ("PipelineServiceTransport",) +__all__ = ( + 'PipelineServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py index b7d20db080..4f19145175 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -48,24 +48,21 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -111,7 +108,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -119,70 +119,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -190,32 +170,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -245,12 +213,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -262,18 +231,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline, - ]: + def create_training_pipeline(self) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + gca_training_pipeline.TrainingPipeline]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -289,21 +257,18 @@ def create_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_training_pipeline" not in self._stubs: - self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline", + if 'create_training_pipeline' not in self._stubs: + self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline', request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["create_training_pipeline"] + return self._stubs['create_training_pipeline'] @property - def get_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline, - ]: + def get_training_pipeline(self) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -318,21 +283,18 @@ def get_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_training_pipeline" not in self._stubs: - self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline", + if 'get_training_pipeline' not in self._stubs: + self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline', request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["get_training_pipeline"] + return self._stubs['get_training_pipeline'] @property - def list_training_pipelines( - self, - ) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, - ]: + def list_training_pipelines(self) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -347,20 +309,18 @@ def list_training_pipelines( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_training_pipelines" not in self._stubs: - self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines", + if 'list_training_pipelines' not in self._stubs: + self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines', request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs["list_training_pipelines"] + return self._stubs['list_training_pipelines'] @property - def delete_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation - ]: + def delete_training_pipeline(self) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + operations.Operation]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -375,18 +335,18 @@ def delete_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_training_pipeline" not in self._stubs: - self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline", + if 'delete_training_pipeline' not in self._stubs: + self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline', request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_training_pipeline"] + return self._stubs['delete_training_pipeline'] @property - def cancel_training_pipeline( - self, - ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: + def cancel_training_pipeline(self) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + empty.Empty]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -413,13 +373,15 @@ def cancel_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_training_pipeline" not in self._stubs: - self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline", + if 'cancel_training_pipeline' not in self._stubs: + self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline', request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_training_pipeline"] + return self._stubs['cancel_training_pipeline'] -__all__ = ("PipelineServiceGrpcTransport",) +__all__ = ( + 'PipelineServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py index ceed94071f..8a0f1f7534 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import pipeline_service @@ -55,18 +55,16 @@ class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -92,24 +90,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -144,10 +140,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -156,7 +152,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -164,70 +163,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -235,18 +214,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -275,12 +244,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - Awaitable[gca_training_pipeline.TrainingPipeline], - ]: + def create_training_pipeline(self) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline]]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -296,21 +262,18 @@ def create_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_training_pipeline" not in self._stubs: - self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline", + if 'create_training_pipeline' not in self._stubs: + self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline', request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["create_training_pipeline"] + return self._stubs['create_training_pipeline'] @property - def get_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - Awaitable[training_pipeline.TrainingPipeline], - ]: + def get_training_pipeline(self) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline]]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -325,21 +288,18 @@ def get_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_training_pipeline" not in self._stubs: - self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline", + if 'get_training_pipeline' not in self._stubs: + self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline', request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["get_training_pipeline"] + return self._stubs['get_training_pipeline'] @property - def list_training_pipelines( - self, - ) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - Awaitable[pipeline_service.ListTrainingPipelinesResponse], - ]: + def list_training_pipelines(self) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse]]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -354,21 +314,18 @@ def list_training_pipelines( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_training_pipelines" not in self._stubs: - self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines", + if 'list_training_pipelines' not in self._stubs: + self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines', request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs["list_training_pipelines"] + return self._stubs['list_training_pipelines'] @property - def delete_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - Awaitable[operations.Operation], - ]: + def delete_training_pipeline(self) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -383,20 +340,18 @@ def delete_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_training_pipeline" not in self._stubs: - self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline", + if 'delete_training_pipeline' not in self._stubs: + self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline', request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_training_pipeline"] + return self._stubs['delete_training_pipeline'] @property - def cancel_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] - ]: + def cancel_training_pipeline(self) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -423,13 +378,15 @@ def cancel_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_training_pipeline" not in self._stubs: - self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline", + if 'cancel_training_pipeline' not in self._stubs: + self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline', request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_training_pipeline"] + return self._stubs['cancel_training_pipeline'] -__all__ = ("PipelineServiceGrpcAsyncIOTransport",) +__all__ = ( + 'PipelineServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1/services/prediction_service/__init__.py index 0c847693e0..d4047c335d 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PredictionServiceAsyncClient __all__ = ( - "PredictionServiceClient", - "PredictionServiceAsyncClient", + 'PredictionServiceClient', + 'PredictionServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index cc6d011e88..299694bdce 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore @@ -47,34 +47,20 @@ class PredictionServiceAsyncClient: endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) - common_billing_account_path = staticmethod( - PredictionServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - PredictionServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(PredictionServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PredictionServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - PredictionServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(PredictionServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - PredictionServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - PredictionServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(PredictionServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PredictionServiceClient.parse_common_organization_path) common_project_path = staticmethod(PredictionServiceClient.common_project_path) - parse_common_project_path = staticmethod( - PredictionServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(PredictionServiceClient.parse_common_project_path) common_location_path = staticmethod(PredictionServiceClient.common_location_path) - parse_common_location_path = staticmethod( - PredictionServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(PredictionServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -117,18 +103,14 @@ def transport(self) -> PredictionServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) - ) + get_transport_class = functools.partial(type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -167,19 +149,19 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def predict( - self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + async def predict(self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -239,10 +221,8 @@ async def predict( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = prediction_service.PredictRequest(request) @@ -268,24 +248,38 @@ async def predict( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PredictionServiceAsyncClient",) +__all__ = ( + 'PredictionServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index 029fb851b8..7d9294a251 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore @@ -47,16 +47,13 @@ class PredictionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry['grpc'] = PredictionServiceGrpcTransport + _transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[PredictionServiceTransport]] - _transport_registry["grpc"] = PredictionServiceGrpcTransport - _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport - - def get_transport_class( - cls, label: str = None, - ) -> Type[PredictionServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[PredictionServiceTransport]: """Return an appropriate transport class. Args: @@ -107,7 +104,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -142,8 +139,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PredictionServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -158,88 +156,77 @@ def transport(self) -> PredictionServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -283,9 +270,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -295,9 +280,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -309,9 +292,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -323,10 +304,8 @@ def __init__( if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -345,17 +324,16 @@ def __init__( client_info=client_info, ) - def predict( - self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + def predict(self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -415,10 +393,8 @@ def predict( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a prediction_service.PredictRequest. @@ -444,24 +420,38 @@ def predict( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PredictionServiceClient",) +__all__ = ( + 'PredictionServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py index 9ec1369a05..15b5acb198 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] -_transport_registry["grpc"] = PredictionServiceGrpcTransport -_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = PredictionServiceGrpcTransport +_transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport __all__ = ( - "PredictionServiceTransport", - "PredictionServiceGrpcTransport", - "PredictionServiceGrpcAsyncIOTransport", + 'PredictionServiceTransport', + 'PredictionServiceGrpcTransport', + 'PredictionServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index 311639daaf..9e8a9841c0 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -69,59 +69,59 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( - self.predict, default_timeout=None, client_info=client_info, + self.predict, + default_timeout=None, + client_info=client_info, ), + } @property - def predict( - self, - ) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse], - ], - ]: + def predict(self) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse] + ]]: raise NotImplementedError() -__all__ = ("PredictionServiceTransport",) +__all__ = ( + 'PredictionServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index 86aef5e81a..484a1193b1 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -18,10 +18,10 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -43,24 +43,21 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -106,7 +103,9 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -114,70 +113,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -185,31 +164,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -239,20 +207,19 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property - def predict( - self, - ) -> Callable[ - [prediction_service.PredictRequest], prediction_service.PredictResponse - ]: + def predict(self) -> Callable[ + [prediction_service.PredictRequest], + prediction_service.PredictResponse]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -267,13 +234,15 @@ def predict( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "predict" not in self._stubs: - self._stubs["predict"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PredictionService/Predict", + if 'predict' not in self._stubs: + self._stubs['predict'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PredictionService/Predict', request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs["predict"] + return self._stubs['predict'] -__all__ = ("PredictionServiceGrpcTransport",) +__all__ = ( + 'PredictionServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py index 620f340813..87a9970365 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py @@ -18,13 +18,13 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import prediction_service @@ -50,18 +50,16 @@ class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -87,24 +85,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -139,10 +135,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -151,7 +147,9 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -159,70 +157,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -230,17 +208,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -253,12 +222,9 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def predict( - self, - ) -> Callable[ - [prediction_service.PredictRequest], - Awaitable[prediction_service.PredictResponse], - ]: + def predict(self) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse]]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -273,13 +239,15 @@ def predict( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "predict" not in self._stubs: - self._stubs["predict"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.PredictionService/Predict", + if 'predict' not in self._stubs: + self._stubs['predict'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.PredictionService/Predict', request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs["predict"] + return self._stubs['predict'] -__all__ = ("PredictionServiceGrpcAsyncIOTransport",) +__all__ = ( + 'PredictionServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py index 49e9cdf0a0..e4247d7758 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import SpecialistPoolServiceAsyncClient __all__ = ( - "SpecialistPoolServiceClient", - "SpecialistPoolServiceAsyncClient", + 'SpecialistPoolServiceClient', + 'SpecialistPoolServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index 57e2b8a0a7..be193ead83 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,43 +57,23 @@ class SpecialistPoolServiceAsyncClient: DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT - specialist_pool_path = staticmethod( - SpecialistPoolServiceClient.specialist_pool_path - ) - parse_specialist_pool_path = staticmethod( - SpecialistPoolServiceClient.parse_specialist_pool_path - ) + specialist_pool_path = staticmethod(SpecialistPoolServiceClient.specialist_pool_path) + parse_specialist_pool_path = staticmethod(SpecialistPoolServiceClient.parse_specialist_pool_path) - common_billing_account_path = staticmethod( - SpecialistPoolServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - SpecialistPoolServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(SpecialistPoolServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(SpecialistPoolServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - SpecialistPoolServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(SpecialistPoolServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - SpecialistPoolServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - SpecialistPoolServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(SpecialistPoolServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(SpecialistPoolServiceClient.parse_common_organization_path) common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) - parse_common_project_path = staticmethod( - SpecialistPoolServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(SpecialistPoolServiceClient.parse_common_project_path) - common_location_path = staticmethod( - SpecialistPoolServiceClient.common_location_path - ) - parse_common_location_path = staticmethod( - SpecialistPoolServiceClient.parse_common_location_path - ) + common_location_path = staticmethod(SpecialistPoolServiceClient.common_location_path) + parse_common_location_path = staticmethod(SpecialistPoolServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -136,19 +116,14 @@ def transport(self) -> SpecialistPoolServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(SpecialistPoolServiceClient).get_transport_class, - type(SpecialistPoolServiceClient), - ) + get_transport_class = functools.partial(type(SpecialistPoolServiceClient).get_transport_class, type(SpecialistPoolServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -187,18 +162,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_specialist_pool( - self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_specialist_pool(self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a SpecialistPool. Args: @@ -246,10 +221,8 @@ async def create_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.CreateSpecialistPoolRequest(request) @@ -272,11 +245,18 @@ async def create_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -289,15 +269,14 @@ async def create_specialist_pool( # Done; return the response. return response - async def get_specialist_pool( - self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + async def get_specialist_pool(self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -340,10 +319,8 @@ async def get_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.GetSpecialistPoolRequest(request) @@ -364,24 +341,30 @@ async def get_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_specialist_pools( - self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsAsyncPager: + async def list_specialist_pools(self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: r"""Lists SpecialistPools in a Location. Args: @@ -417,10 +400,8 @@ async def list_specialist_pools( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.ListSpecialistPoolsRequest(request) @@ -441,30 +422,39 @@ async def list_specialist_pools( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListSpecialistPoolsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_specialist_pool( - self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_specialist_pool(self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -511,10 +501,8 @@ async def delete_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.DeleteSpecialistPoolRequest(request) @@ -535,11 +523,18 @@ async def delete_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -552,16 +547,15 @@ async def delete_specialist_pool( # Done; return the response. return response - async def update_specialist_pool( - self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_specialist_pool(self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates a SpecialistPool. Args: @@ -608,10 +602,8 @@ async def update_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.UpdateSpecialistPoolRequest(request) @@ -634,13 +626,18 @@ async def update_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("specialist_pool.name", request.specialist_pool.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('specialist_pool.name', request.specialist_pool.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -654,14 +651,21 @@ async def update_specialist_pool( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("SpecialistPoolServiceAsyncClient",) +__all__ = ( + 'SpecialistPoolServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index c6429b54f8..efb32eaa6e 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,16 +54,13 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport + _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport - _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport - - def get_transport_class( - cls, label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -120,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -155,8 +152,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: SpecialistPoolServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -171,88 +169,77 @@ def transport(self) -> SpecialistPoolServiceTransport: return self._transport @staticmethod - def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: + def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" - return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( - project=project, location=location, specialist_pool=specialist_pool, - ) + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) @staticmethod - def parse_specialist_pool_path(path: str) -> Dict[str, str]: + def parse_specialist_pool_path(path: str) -> Dict[str,str]: """Parse a specialist_pool path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -296,9 +283,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -308,9 +293,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -322,9 +305,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -336,10 +317,8 @@ def __init__( if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -358,16 +337,15 @@ def __init__( client_info=client_info, ) - def create_specialist_pool( - self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_specialist_pool(self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -415,10 +393,8 @@ def create_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -442,11 +418,18 @@ def create_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -459,15 +442,14 @@ def create_specialist_pool( # Done; return the response. return response - def get_specialist_pool( - self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool(self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -510,10 +492,8 @@ def get_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -535,24 +515,30 @@ def get_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_specialist_pools( - self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools(self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -588,10 +574,8 @@ def list_specialist_pools( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -613,30 +597,39 @@ def list_specialist_pools( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool( - self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_specialist_pool(self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -683,10 +676,8 @@ def delete_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -708,11 +699,18 @@ def delete_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -725,16 +723,15 @@ def delete_specialist_pool( # Done; return the response. return response - def update_specialist_pool( - self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_specialist_pool(self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -781,10 +778,8 @@ def update_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -808,13 +803,18 @@ def update_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("specialist_pool.name", request.specialist_pool.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('specialist_pool.name', request.specialist_pool.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -828,14 +828,21 @@ def update_specialist_pool( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("SpecialistPoolServiceClient",) +__all__ = ( + 'SpecialistPoolServiceClient', +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py index e64a827049..87590e0e87 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1.types import specialist_pool from google.cloud.aiplatform_v1.types import specialist_pool_service @@ -47,15 +38,12 @@ class ListSpecialistPoolsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: yield from page.specialist_pools def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListSpecialistPoolsAsyncPager: @@ -109,17 +97,12 @@ class ListSpecialistPoolsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -141,9 +124,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + async def pages(self) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -159,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py index 1bb2fbf22a..80de7b209f 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py @@ -24,14 +24,12 @@ # Compile a registry of transports. -_transport_registry = ( - OrderedDict() -) # type: Dict[str, Type[SpecialistPoolServiceTransport]] -_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport -_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport +_transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] +_transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport +_transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( - "SpecialistPoolServiceTransport", - "SpecialistPoolServiceGrpcTransport", - "SpecialistPoolServiceGrpcAsyncIOTransport", + 'SpecialistPoolServiceTransport', + 'SpecialistPoolServiceGrpcTransport', + 'SpecialistPoolServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py index 56de21b988..878e095edb 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -72,40 +72,38 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -115,7 +113,9 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, default_timeout=None, client_info=client_info, + self.get_specialist_pool, + default_timeout=None, + client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, @@ -132,6 +132,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + } @property @@ -140,55 +141,51 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def create_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool], - ], - ]: + def get_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool] + ]]: raise NotImplementedError() @property - def list_specialist_pools( - self, - ) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], - ], - ]: + def list_specialist_pools(self) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ]]: raise NotImplementedError() @property - def delete_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def update_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def update_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() -__all__ = ("SpecialistPoolServiceTransport",) +__all__ = ( + 'SpecialistPoolServiceTransport', +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py index cb8904bc07..7574c12f22 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,24 +51,21 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -114,7 +111,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -122,70 +122,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -193,32 +173,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -248,12 +216,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -265,17 +234,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation - ]: + def create_specialist_pool(self) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -290,21 +259,18 @@ def create_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_specialist_pool" not in self._stubs: - self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool", + if 'create_specialist_pool' not in self._stubs: + self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool', request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_specialist_pool"] + return self._stubs['create_specialist_pool'] @property - def get_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool, - ]: + def get_specialist_pool(self) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -319,21 +285,18 @@ def get_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_specialist_pool" not in self._stubs: - self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool", + if 'get_specialist_pool' not in self._stubs: + self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool', request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs["get_specialist_pool"] + return self._stubs['get_specialist_pool'] @property - def list_specialist_pools( - self, - ) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, - ]: + def list_specialist_pools(self) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -348,20 +311,18 @@ def list_specialist_pools( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_specialist_pools" not in self._stubs: - self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools", + if 'list_specialist_pools' not in self._stubs: + self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools', request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs["list_specialist_pools"] + return self._stubs['list_specialist_pools'] @property - def delete_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation - ]: + def delete_specialist_pool(self) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -377,20 +338,18 @@ def delete_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_specialist_pool" not in self._stubs: - self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool", + if 'delete_specialist_pool' not in self._stubs: + self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool', request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_specialist_pool"] + return self._stubs['delete_specialist_pool'] @property - def update_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation - ]: + def update_specialist_pool(self) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -405,13 +364,15 @@ def update_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_specialist_pool" not in self._stubs: - self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool", + if 'update_specialist_pool' not in self._stubs: + self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool', request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["update_specialist_pool"] + return self._stubs['update_specialist_pool'] -__all__ = ("SpecialistPoolServiceGrpcTransport",) +__all__ = ( + 'SpecialistPoolServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py index 566d0b022b..2766d7848b 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import specialist_pool @@ -58,18 +58,16 @@ class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -95,24 +93,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -147,10 +143,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -159,7 +155,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -167,70 +166,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -238,18 +217,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -278,12 +247,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: + def create_specialist_pool(self) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -298,21 +264,18 @@ def create_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_specialist_pool" not in self._stubs: - self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool", + if 'create_specialist_pool' not in self._stubs: + self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool', request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_specialist_pool"] + return self._stubs['create_specialist_pool'] @property - def get_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - Awaitable[specialist_pool.SpecialistPool], - ]: + def get_specialist_pool(self) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool]]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -327,21 +290,18 @@ def get_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_specialist_pool" not in self._stubs: - self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool", + if 'get_specialist_pool' not in self._stubs: + self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool', request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs["get_specialist_pool"] + return self._stubs['get_specialist_pool'] @property - def list_specialist_pools( - self, - ) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], - ]: + def list_specialist_pools(self) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -356,21 +316,18 @@ def list_specialist_pools( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_specialist_pools" not in self._stubs: - self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools", + if 'list_specialist_pools' not in self._stubs: + self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools', request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs["list_specialist_pools"] + return self._stubs['list_specialist_pools'] @property - def delete_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: + def delete_specialist_pool(self) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -386,21 +343,18 @@ def delete_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_specialist_pool" not in self._stubs: - self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool", + if 'delete_specialist_pool' not in self._stubs: + self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool', request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_specialist_pool"] + return self._stubs['delete_specialist_pool'] @property - def update_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: + def update_specialist_pool(self) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -415,13 +369,15 @@ def update_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_specialist_pool" not in self._stubs: - self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool", + if 'update_specialist_pool' not in self._stubs: + self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool', request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["update_specialist_pool"] + return self._stubs['update_specialist_pool'] -__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) +__all__ = ( + 'SpecialistPoolServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index 6d7c9ca42f..b33ec9f9b8 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -15,10 +15,18 @@ # limitations under the License. # -from .annotation import Annotation -from .annotation_spec import AnnotationSpec -from .batch_prediction_job import BatchPredictionJob -from .completion_stats import CompletionStats +from .annotation import ( + Annotation, +) +from .annotation_spec import ( + AnnotationSpec, +) +from .batch_prediction_job import ( + BatchPredictionJob, +) +from .completion_stats import ( + CompletionStats, +) from .custom_job import ( ContainerSpec, CustomJob, @@ -27,7 +35,9 @@ Scheduling, WorkerPoolSpec, ) -from .data_item import DataItem +from .data_item import ( + DataItem, +) from .data_labeling_job import ( ActiveLearningConfig, DataLabelingJob, @@ -59,8 +69,12 @@ ListDatasetsResponse, UpdateDatasetRequest, ) -from .deployed_model_ref import DeployedModelRef -from .encryption_spec import EncryptionSpec +from .deployed_model_ref import ( + DeployedModelRef, +) +from .encryption_spec import ( + EncryptionSpec, +) from .endpoint import ( DeployedModel, Endpoint, @@ -80,8 +94,12 @@ UndeployModelResponse, UpdateEndpointRequest, ) -from .env_var import EnvVar -from .hyperparameter_tuning_job import HyperparameterTuningJob +from .env_var import ( + EnvVar, +) +from .hyperparameter_tuning_job import ( + HyperparameterTuningJob, +) from .io import ( BigQueryDestination, BigQuerySource, @@ -123,8 +141,12 @@ MachineSpec, ResourcesConsumed, ) -from .manual_batch_tuning_parameters import ManualBatchTuningParameters -from .migratable_resource import MigratableResource +from .manual_batch_tuning_parameters import ( + ManualBatchTuningParameters, +) +from .migratable_resource import ( + MigratableResource, +) from .migration_service import ( BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, @@ -140,8 +162,12 @@ Port, PredictSchemata, ) -from .model_evaluation import ModelEvaluation -from .model_evaluation_slice import ModelEvaluationSlice +from .model_evaluation import ( + ModelEvaluation, +) +from .model_evaluation_slice import ( + ModelEvaluationSlice, +) from .model_service import ( DeleteModelRequest, ExportModelOperationMetadata, @@ -177,7 +203,9 @@ PredictRequest, PredictResponse, ) -from .specialist_pool import SpecialistPool +from .specialist_pool import ( + SpecialistPool, +) from .specialist_pool_service import ( CreateSpecialistPoolOperationMetadata, CreateSpecialistPoolRequest, @@ -201,161 +229,163 @@ TimestampSplit, TrainingPipeline, ) -from .user_action_reference import UserActionReference +from .user_action_reference import ( + UserActionReference, +) __all__ = ( - "AcceleratorType", - "Annotation", - "AnnotationSpec", - "BatchPredictionJob", - "CompletionStats", - "ContainerSpec", - "CustomJob", - "CustomJobSpec", - "PythonPackageSpec", - "Scheduling", - "WorkerPoolSpec", - "DataItem", - "ActiveLearningConfig", - "DataLabelingJob", - "SampleConfig", - "TrainingConfig", - "Dataset", - "ExportDataConfig", - "ImportDataConfig", - "CreateDatasetOperationMetadata", - "CreateDatasetRequest", - "DeleteDatasetRequest", - "ExportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "GetAnnotationSpecRequest", - "GetDatasetRequest", - "ImportDataOperationMetadata", - "ImportDataRequest", - "ImportDataResponse", - "ListAnnotationsRequest", - "ListAnnotationsResponse", - "ListDataItemsRequest", - "ListDataItemsResponse", - "ListDatasetsRequest", - "ListDatasetsResponse", - "UpdateDatasetRequest", - "DeployedModelRef", - "EncryptionSpec", - "DeployedModel", - "Endpoint", - "CreateEndpointOperationMetadata", - "CreateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelOperationMetadata", - "DeployModelRequest", - "DeployModelResponse", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UndeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UpdateEndpointRequest", - "EnvVar", - "HyperparameterTuningJob", - "BigQueryDestination", - "BigQuerySource", - "ContainerRegistryDestination", - "GcsDestination", - "GcsSource", - "CancelBatchPredictionJobRequest", - "CancelCustomJobRequest", - "CancelDataLabelingJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "CreateCustomJobRequest", - "CreateDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "DeleteBatchPredictionJobRequest", - "DeleteCustomJobRequest", - "DeleteDataLabelingJobRequest", - "DeleteHyperparameterTuningJobRequest", - "GetBatchPredictionJobRequest", - "GetCustomJobRequest", - "GetDataLabelingJobRequest", - "GetHyperparameterTuningJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "JobState", - "AutomaticResources", - "BatchDedicatedResources", - "DedicatedResources", - "DiskSpec", - "MachineSpec", - "ResourcesConsumed", - "ManualBatchTuningParameters", - "MigratableResource", - "BatchMigrateResourcesOperationMetadata", - "BatchMigrateResourcesRequest", - "BatchMigrateResourcesResponse", - "MigrateResourceRequest", - "MigrateResourceResponse", - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", - "Model", - "ModelContainerSpec", - "Port", - "PredictSchemata", - "ModelEvaluation", - "ModelEvaluationSlice", - "DeleteModelRequest", - "ExportModelOperationMetadata", - "ExportModelRequest", - "ExportModelResponse", - "GetModelEvaluationRequest", - "GetModelEvaluationSliceRequest", - "GetModelRequest", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", - "UploadModelOperationMetadata", - "UploadModelRequest", - "UploadModelResponse", - "DeleteOperationMetadata", - "GenericOperationMetadata", - "CancelTrainingPipelineRequest", - "CreateTrainingPipelineRequest", - "DeleteTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "PipelineState", - "PredictRequest", - "PredictResponse", - "SpecialistPool", - "CreateSpecialistPoolOperationMetadata", - "CreateSpecialistPoolRequest", - "DeleteSpecialistPoolRequest", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "UpdateSpecialistPoolOperationMetadata", - "UpdateSpecialistPoolRequest", - "Measurement", - "StudySpec", - "Trial", - "FilterSplit", - "FractionSplit", - "InputDataConfig", - "PredefinedSplit", - "TimestampSplit", - "TrainingPipeline", - "UserActionReference", + 'AcceleratorType', + 'Annotation', + 'AnnotationSpec', + 'BatchPredictionJob', + 'CompletionStats', + 'ContainerSpec', + 'CustomJob', + 'CustomJobSpec', + 'PythonPackageSpec', + 'Scheduling', + 'WorkerPoolSpec', + 'DataItem', + 'ActiveLearningConfig', + 'DataLabelingJob', + 'SampleConfig', + 'TrainingConfig', + 'Dataset', + 'ExportDataConfig', + 'ImportDataConfig', + 'CreateDatasetOperationMetadata', + 'CreateDatasetRequest', + 'DeleteDatasetRequest', + 'ExportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'GetAnnotationSpecRequest', + 'GetDatasetRequest', + 'ImportDataOperationMetadata', + 'ImportDataRequest', + 'ImportDataResponse', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'UpdateDatasetRequest', + 'DeployedModelRef', + 'EncryptionSpec', + 'DeployedModel', + 'Endpoint', + 'CreateEndpointOperationMetadata', + 'CreateEndpointRequest', + 'DeleteEndpointRequest', + 'DeployModelOperationMetadata', + 'DeployModelRequest', + 'DeployModelResponse', + 'GetEndpointRequest', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'UndeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UpdateEndpointRequest', + 'EnvVar', + 'HyperparameterTuningJob', + 'BigQueryDestination', + 'BigQuerySource', + 'ContainerRegistryDestination', + 'GcsDestination', + 'GcsSource', + 'CancelBatchPredictionJobRequest', + 'CancelCustomJobRequest', + 'CancelDataLabelingJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CreateBatchPredictionJobRequest', + 'CreateCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'CreateHyperparameterTuningJobRequest', + 'DeleteBatchPredictionJobRequest', + 'DeleteCustomJobRequest', + 'DeleteDataLabelingJobRequest', + 'DeleteHyperparameterTuningJobRequest', + 'GetBatchPredictionJobRequest', + 'GetCustomJobRequest', + 'GetDataLabelingJobRequest', + 'GetHyperparameterTuningJobRequest', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'JobState', + 'AutomaticResources', + 'BatchDedicatedResources', + 'DedicatedResources', + 'DiskSpec', + 'MachineSpec', + 'ResourcesConsumed', + 'ManualBatchTuningParameters', + 'MigratableResource', + 'BatchMigrateResourcesOperationMetadata', + 'BatchMigrateResourcesRequest', + 'BatchMigrateResourcesResponse', + 'MigrateResourceRequest', + 'MigrateResourceResponse', + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'Model', + 'ModelContainerSpec', + 'Port', + 'PredictSchemata', + 'ModelEvaluation', + 'ModelEvaluationSlice', + 'DeleteModelRequest', + 'ExportModelOperationMetadata', + 'ExportModelRequest', + 'ExportModelResponse', + 'GetModelEvaluationRequest', + 'GetModelEvaluationSliceRequest', + 'GetModelRequest', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'ListModelsRequest', + 'ListModelsResponse', + 'UpdateModelRequest', + 'UploadModelOperationMetadata', + 'UploadModelRequest', + 'UploadModelResponse', + 'DeleteOperationMetadata', + 'GenericOperationMetadata', + 'CancelTrainingPipelineRequest', + 'CreateTrainingPipelineRequest', + 'DeleteTrainingPipelineRequest', + 'GetTrainingPipelineRequest', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'PipelineState', + 'PredictRequest', + 'PredictResponse', + 'SpecialistPool', + 'CreateSpecialistPoolOperationMetadata', + 'CreateSpecialistPoolRequest', + 'DeleteSpecialistPoolRequest', + 'GetSpecialistPoolRequest', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'UpdateSpecialistPoolOperationMetadata', + 'UpdateSpecialistPoolRequest', + 'Measurement', + 'StudySpec', + 'Trial', + 'FilterSplit', + 'FractionSplit', + 'InputDataConfig', + 'PredefinedSplit', + 'TimestampSplit', + 'TrainingPipeline', + 'UserActionReference', ) diff --git a/google/cloud/aiplatform_v1/types/accelerator_type.py b/google/cloud/aiplatform_v1/types/accelerator_type.py index 640436c38c..b22abd8ffb 100644 --- a/google/cloud/aiplatform_v1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1/types/accelerator_type.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"AcceleratorType",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'AcceleratorType', + }, ) diff --git a/google/cloud/aiplatform_v1/types/annotation.py b/google/cloud/aiplatform_v1/types/annotation.py index 000ca49dcb..3a08c3dead 100644 --- a/google/cloud/aiplatform_v1/types/annotation.py +++ b/google/cloud/aiplatform_v1/types/annotation.py @@ -24,7 +24,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"Annotation",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'Annotation', + }, ) @@ -91,16 +94,22 @@ class Annotation(proto.Message): payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + payload = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field( - proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, + annotation_source = proto.Field(proto.MESSAGE, number=5, + message=user_action_reference.UserActionReference, ) labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1/types/annotation_spec.py b/google/cloud/aiplatform_v1/types/annotation_spec.py index 41f228ad72..4bcd10d1ba 100644 --- a/google/cloud/aiplatform_v1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1/types/annotation_spec.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"AnnotationSpec",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'AnnotationSpec', + }, ) @@ -55,9 +58,13 @@ class AnnotationSpec(proto.Message): display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1/types/batch_prediction_job.py index d2d8f02203..a75a861570 100644 --- a/google/cloud/aiplatform_v1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1/types/batch_prediction_job.py @@ -23,16 +23,17 @@ from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources -from google.cloud.aiplatform_v1.types import ( - manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, -) +from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"BatchPredictionJob",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'BatchPredictionJob', + }, ) @@ -147,7 +148,6 @@ class BatchPredictionJob(proto.Message): resources created by the BatchPredictionJob will be encrypted with the provided encryption key. """ - class InputConfig(proto.Message): r"""Configures the input to ``BatchPredictionJob``. @@ -174,12 +174,12 @@ class InputConfig(proto.Message): ``supported_input_storage_formats``. """ - gcs_source = proto.Field( - proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, + gcs_source = proto.Field(proto.MESSAGE, number=2, oneof='source', + message=io.GcsSource, ) - bigquery_source = proto.Field( - proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, + bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', + message=io.BigQuerySource, ) instances_format = proto.Field(proto.STRING, number=1) @@ -250,14 +250,11 @@ class OutputConfig(proto.Message): ``supported_output_storage_formats``. """ - gcs_destination = proto.Field( - proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', + message=io.GcsDestination, ) - bigquery_destination = proto.Field( - proto.MESSAGE, - number=3, - oneof="destination", + bigquery_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', message=io.BigQueryDestination, ) @@ -278,13 +275,9 @@ class OutputInfo(proto.Message): prediction output is written. """ - gcs_output_directory = proto.Field( - proto.STRING, number=1, oneof="output_location" - ) + gcs_output_directory = proto.Field(proto.STRING, number=1, oneof='output_location') - bigquery_output_dataset = proto.Field( - proto.STRING, number=2, oneof="output_location" - ) + bigquery_output_dataset = proto.Field(proto.STRING, number=2, oneof='output_location') name = proto.Field(proto.STRING, number=1) @@ -292,52 +285,70 @@ class OutputInfo(proto.Message): model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) + input_config = proto.Field(proto.MESSAGE, number=4, + message=InputConfig, + ) - model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + model_parameters = proto.Field(proto.MESSAGE, number=5, + message=struct.Value, + ) - output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) + output_config = proto.Field(proto.MESSAGE, number=6, + message=OutputConfig, + ) - dedicated_resources = proto.Field( - proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, + dedicated_resources = proto.Field(proto.MESSAGE, number=7, + message=machine_resources.BatchDedicatedResources, ) - manual_batch_tuning_parameters = proto.Field( - proto.MESSAGE, - number=8, + manual_batch_tuning_parameters = proto.Field(proto.MESSAGE, number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) - output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) + output_info = proto.Field(proto.MESSAGE, number=9, + message=OutputInfo, + ) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=10, + enum=job_state.JobState, + ) - error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=11, + message=status.Status, + ) - partial_failures = proto.RepeatedField( - proto.MESSAGE, number=12, message=status.Status, + partial_failures = proto.RepeatedField(proto.MESSAGE, number=12, + message=status.Status, ) - resources_consumed = proto.Field( - proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, + resources_consumed = proto.Field(proto.MESSAGE, number=13, + message=machine_resources.ResourcesConsumed, ) - completion_stats = proto.Field( - proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, + completion_stats = proto.Field(proto.MESSAGE, number=14, + message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=15, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=16, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=17, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=18, + message=timestamp.Timestamp, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=19) - encryption_spec = proto.Field( - proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=24, + message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1/types/completion_stats.py b/google/cloud/aiplatform_v1/types/completion_stats.py index 05648d82c4..8a0f151024 100644 --- a/google/cloud/aiplatform_v1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1/types/completion_stats.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"CompletionStats",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'CompletionStats', + }, ) diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py index c97cba6d82..176e077042 100644 --- a/google/cloud/aiplatform_v1/types/custom_job.py +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -29,14 +29,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "CustomJob", - "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", - "PythonPackageSpec", - "Scheduling", + 'CustomJob', + 'CustomJobSpec', + 'WorkerPoolSpec', + 'ContainerSpec', + 'PythonPackageSpec', + 'Scheduling', }, ) @@ -96,24 +96,38 @@ class CustomJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) + job_spec = proto.Field(proto.MESSAGE, number=4, + message='CustomJobSpec', + ) - state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=5, + enum=job_state.JobState, + ) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=10, + message=status.Status, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=11) - encryption_spec = proto.Field( - proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=12, + message=gca_encryption_spec.EncryptionSpec, ) @@ -176,18 +190,20 @@ class CustomJobSpec(proto.Message): ``//logs/`` """ - worker_pool_specs = proto.RepeatedField( - proto.MESSAGE, number=1, message="WorkerPoolSpec", + worker_pool_specs = proto.RepeatedField(proto.MESSAGE, number=1, + message='WorkerPoolSpec', ) - scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) + scheduling = proto.Field(proto.MESSAGE, number=3, + message='Scheduling', + ) service_account = proto.Field(proto.STRING, number=4) network = proto.Field(proto.STRING, number=5) - base_output_directory = proto.Field( - proto.MESSAGE, number=6, message=io.GcsDestination, + base_output_directory = proto.Field(proto.MESSAGE, number=6, + message=io.GcsDestination, ) @@ -209,22 +225,22 @@ class WorkerPoolSpec(proto.Message): Disk spec. """ - container_spec = proto.Field( - proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", + container_spec = proto.Field(proto.MESSAGE, number=6, oneof='task', + message='ContainerSpec', ) - python_package_spec = proto.Field( - proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", + python_package_spec = proto.Field(proto.MESSAGE, number=7, oneof='task', + message='PythonPackageSpec', ) - machine_spec = proto.Field( - proto.MESSAGE, number=1, message=machine_resources.MachineSpec, + machine_spec = proto.Field(proto.MESSAGE, number=1, + message=machine_resources.MachineSpec, ) replica_count = proto.Field(proto.INT64, number=2) - disk_spec = proto.Field( - proto.MESSAGE, number=5, message=machine_resources.DiskSpec, + disk_spec = proto.Field(proto.MESSAGE, number=5, + message=machine_resources.DiskSpec, ) @@ -254,7 +270,9 @@ class ContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) + env = proto.RepeatedField(proto.MESSAGE, number=4, + message=env_var.EnvVar, + ) class PythonPackageSpec(proto.Message): @@ -292,7 +310,9 @@ class PythonPackageSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=4) - env = proto.RepeatedField(proto.MESSAGE, number=5, message=env_var.EnvVar,) + env = proto.RepeatedField(proto.MESSAGE, number=5, + message=env_var.EnvVar, + ) class Scheduling(proto.Message): @@ -310,7 +330,9 @@ class Scheduling(proto.Message): to workers leaving and joining a job. """ - timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + timeout = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1/types/data_item.py b/google/cloud/aiplatform_v1/types/data_item.py index 20ff14a0d8..d29e056d16 100644 --- a/google/cloud/aiplatform_v1/types/data_item.py +++ b/google/cloud/aiplatform_v1/types/data_item.py @@ -23,7 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"DataItem",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'DataItem', + }, ) @@ -70,13 +73,19 @@ class DataItem(proto.Message): name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=2, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + payload = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1/types/data_labeling_job.py b/google/cloud/aiplatform_v1/types/data_labeling_job.py index e1058737bf..8caca23d09 100644 --- a/google/cloud/aiplatform_v1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1/types/data_labeling_job.py @@ -27,12 +27,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "DataLabelingJob", - "ActiveLearningConfig", - "SampleConfig", - "TrainingConfig", + 'DataLabelingJob', + 'ActiveLearningConfig', + 'SampleConfig', + 'TrainingConfig', }, ) @@ -154,30 +154,42 @@ class DataLabelingJob(proto.Message): inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) + inputs = proto.Field(proto.MESSAGE, number=7, + message=struct.Value, + ) - state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=8, + enum=job_state.JobState, + ) labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) + current_spend = proto.Field(proto.MESSAGE, number=14, + message=money.Money, + ) - create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=10, + message=timestamp.Timestamp, + ) - error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=22, + message=status.Status, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=11) specialist_pools = proto.RepeatedField(proto.STRING, number=16) - encryption_spec = proto.Field( - proto.MESSAGE, number=20, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=20, + message=gca_encryption_spec.EncryptionSpec, ) - active_learning_config = proto.Field( - proto.MESSAGE, number=21, message="ActiveLearningConfig", + active_learning_config = proto.Field(proto.MESSAGE, number=21, + message='ActiveLearningConfig', ) @@ -206,17 +218,17 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field( - proto.INT64, number=1, oneof="human_labeling_budget" - ) + max_data_item_count = proto.Field(proto.INT64, number=1, oneof='human_labeling_budget') - max_data_item_percentage = proto.Field( - proto.INT32, number=2, oneof="human_labeling_budget" - ) + max_data_item_percentage = proto.Field(proto.INT32, number=2, oneof='human_labeling_budget') - sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + sample_config = proto.Field(proto.MESSAGE, number=3, + message='SampleConfig', + ) - training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + training_config = proto.Field(proto.MESSAGE, number=4, + message='TrainingConfig', + ) class SampleConfig(proto.Message): @@ -237,7 +249,6 @@ class SampleConfig(proto.Message): strategy will decide which data should be selected for human labeling in every batch. """ - class SampleStrategy(proto.Enum): r"""Sample strategy decides which subset of DataItems should be selected for human labeling in every batch. @@ -245,15 +256,13 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field( - proto.INT32, number=1, oneof="initial_batch_sample_size" - ) + initial_batch_sample_percentage = proto.Field(proto.INT32, number=1, oneof='initial_batch_sample_size') - following_batch_sample_percentage = proto.Field( - proto.INT32, number=3, oneof="following_batch_sample_size" - ) + following_batch_sample_percentage = proto.Field(proto.INT32, number=3, oneof='following_batch_sample_size') - sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + sample_strategy = proto.Field(proto.ENUM, number=5, + enum=SampleStrategy, + ) class TrainingConfig(proto.Message): diff --git a/google/cloud/aiplatform_v1/types/dataset.py b/google/cloud/aiplatform_v1/types/dataset.py index 2f75dce0d5..29e205f9c4 100644 --- a/google/cloud/aiplatform_v1/types/dataset.py +++ b/google/cloud/aiplatform_v1/types/dataset.py @@ -25,8 +25,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", - manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'Dataset', + 'ImportDataConfig', + 'ExportDataConfig', + }, ) @@ -94,18 +98,24 @@ class Dataset(proto.Message): metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) + metadata = proto.Field(proto.MESSAGE, number=8, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) etag = proto.Field(proto.STRING, number=6) labels = proto.MapField(proto.STRING, proto.STRING, number=7) - encryption_spec = proto.Field( - proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=11, + message=gca_encryption_spec.EncryptionSpec, ) @@ -141,8 +151,8 @@ class ImportDataConfig(proto.Message): Object `__. """ - gcs_source = proto.Field( - proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, + gcs_source = proto.Field(proto.MESSAGE, number=1, oneof='source', + message=io.GcsSource, ) data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) @@ -175,8 +185,8 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field( - proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', + message=io.GcsDestination, ) annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/dataset_service.py b/google/cloud/aiplatform_v1/types/dataset_service.py index ccc8cce600..1991dd02ec 100644 --- a/google/cloud/aiplatform_v1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1/types/dataset_service.py @@ -26,26 +26,26 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "CreateDatasetRequest", - "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", - "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", - "GetAnnotationSpecRequest", - "ListAnnotationsRequest", - "ListAnnotationsResponse", + 'CreateDatasetRequest', + 'CreateDatasetOperationMetadata', + 'GetDatasetRequest', + 'UpdateDatasetRequest', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'DeleteDatasetRequest', + 'ImportDataRequest', + 'ImportDataResponse', + 'ImportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportDataOperationMetadata', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'GetAnnotationSpecRequest', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', }, ) @@ -65,7 +65,9 @@ class CreateDatasetRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) + dataset = proto.Field(proto.MESSAGE, number=2, + message=gca_dataset.Dataset, + ) class CreateDatasetOperationMetadata(proto.Message): @@ -77,8 +79,8 @@ class CreateDatasetOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -95,7 +97,9 @@ class GetDatasetRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class UpdateDatasetRequest(proto.Message): @@ -117,9 +121,13 @@ class UpdateDatasetRequest(proto.Message): - ``labels`` """ - dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) + dataset = proto.Field(proto.MESSAGE, number=1, + message=gca_dataset.Dataset, + ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class ListDatasetsRequest(proto.Message): @@ -171,7 +179,9 @@ class ListDatasetsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -192,8 +202,8 @@ class ListDatasetsResponse(proto.Message): def raw_page(self): return self - datasets = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_dataset.Dataset, + datasets = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_dataset.Dataset, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -229,8 +239,8 @@ class ImportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField( - proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, + import_configs = proto.RepeatedField(proto.MESSAGE, number=2, + message=gca_dataset.ImportDataConfig, ) @@ -249,8 +259,8 @@ class ImportDataOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -268,8 +278,8 @@ class ExportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - export_config = proto.Field( - proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, + export_config = proto.Field(proto.MESSAGE, number=2, + message=gca_dataset.ExportDataConfig, ) @@ -299,8 +309,8 @@ class ExportDataOperationMetadata(proto.Message): the directory. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -337,7 +347,9 @@ class ListDataItemsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -358,8 +370,8 @@ class ListDataItemsResponse(proto.Message): def raw_page(self): return self - data_items = proto.RepeatedField( - proto.MESSAGE, number=1, message=data_item.DataItem, + data_items = proto.RepeatedField(proto.MESSAGE, number=1, + message=data_item.DataItem, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -380,7 +392,9 @@ class GetAnnotationSpecRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class ListAnnotationsRequest(proto.Message): @@ -415,7 +429,9 @@ class ListAnnotationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -436,8 +452,8 @@ class ListAnnotationsResponse(proto.Message): def raw_page(self): return self - annotations = proto.RepeatedField( - proto.MESSAGE, number=1, message=annotation.Annotation, + annotations = proto.RepeatedField(proto.MESSAGE, number=1, + message=annotation.Annotation, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1/types/deployed_model_ref.py index 2d53610ed5..ffd0e4182d 100644 --- a/google/cloud/aiplatform_v1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1/types/deployed_model_ref.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"DeployedModelRef",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'DeployedModelRef', + }, ) diff --git a/google/cloud/aiplatform_v1/types/encryption_spec.py b/google/cloud/aiplatform_v1/types/encryption_spec.py index ae908d4b72..a87a91a91e 100644 --- a/google/cloud/aiplatform_v1/types/encryption_spec.py +++ b/google/cloud/aiplatform_v1/types/encryption_spec.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"EncryptionSpec",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'EncryptionSpec', + }, ) diff --git a/google/cloud/aiplatform_v1/types/endpoint.py b/google/cloud/aiplatform_v1/types/endpoint.py index 5cbe3c1b1d..cff9c6b4a7 100644 --- a/google/cloud/aiplatform_v1/types/endpoint.py +++ b/google/cloud/aiplatform_v1/types/endpoint.py @@ -24,7 +24,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"Endpoint", "DeployedModel",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'Endpoint', + 'DeployedModel', + }, ) @@ -92,8 +96,8 @@ class Endpoint(proto.Message): description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField( - proto.MESSAGE, number=4, message="DeployedModel", + deployed_models = proto.RepeatedField(proto.MESSAGE, number=4, + message='DeployedModel', ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) @@ -102,12 +106,16 @@ class Endpoint(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) - encryption_spec = proto.Field( - proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=10, + message=gca_encryption_spec.EncryptionSpec, ) @@ -168,17 +176,11 @@ class DeployedModel(proto.Message): option. """ - dedicated_resources = proto.Field( - proto.MESSAGE, - number=7, - oneof="prediction_resources", + dedicated_resources = proto.Field(proto.MESSAGE, number=7, oneof='prediction_resources', message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field( - proto.MESSAGE, - number=8, - oneof="prediction_resources", + automatic_resources = proto.Field(proto.MESSAGE, number=8, oneof='prediction_resources', message=machine_resources.AutomaticResources, ) @@ -188,7 +190,9 @@ class DeployedModel(proto.Message): display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) service_account = proto.Field(proto.STRING, number=11) diff --git a/google/cloud/aiplatform_v1/types/endpoint_service.py b/google/cloud/aiplatform_v1/types/endpoint_service.py index 24e00bd486..67b893b9aa 100644 --- a/google/cloud/aiplatform_v1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1/types/endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "CreateEndpointRequest", - "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelRequest", - "DeployModelResponse", - "DeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UndeployModelOperationMetadata", + 'CreateEndpointRequest', + 'CreateEndpointOperationMetadata', + 'GetEndpointRequest', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'UpdateEndpointRequest', + 'DeleteEndpointRequest', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UndeployModelOperationMetadata', }, ) @@ -58,7 +58,9 @@ class CreateEndpointRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) + endpoint = proto.Field(proto.MESSAGE, number=2, + message=gca_endpoint.Endpoint, + ) class CreateEndpointOperationMetadata(proto.Message): @@ -70,8 +72,8 @@ class CreateEndpointOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -151,7 +153,9 @@ class ListEndpointsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -173,8 +177,8 @@ class ListEndpointsResponse(proto.Message): def raw_page(self): return self - endpoints = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, + endpoints = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_endpoint.Endpoint, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -193,9 +197,13 @@ class UpdateEndpointRequest(proto.Message): `FieldMask `__. """ - endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) + endpoint = proto.Field(proto.MESSAGE, number=1, + message=gca_endpoint.Endpoint, + ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class DeleteEndpointRequest(proto.Message): @@ -248,8 +256,8 @@ class DeployModelRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field( - proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, + deployed_model = proto.Field(proto.MESSAGE, number=2, + message=gca_endpoint.DeployedModel, ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -265,8 +273,8 @@ class DeployModelResponse(proto.Message): the Endpoint. """ - deployed_model = proto.Field( - proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel, + deployed_model = proto.Field(proto.MESSAGE, number=1, + message=gca_endpoint.DeployedModel, ) @@ -279,8 +287,8 @@ class DeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -329,8 +337,8 @@ class UndeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1/types/env_var.py b/google/cloud/aiplatform_v1/types/env_var.py index f456c15808..8a843cd18c 100644 --- a/google/cloud/aiplatform_v1/types/env_var.py +++ b/google/cloud/aiplatform_v1/types/env_var.py @@ -18,7 +18,12 @@ import proto # type: ignore -__protobuf__ = proto.module(package="google.cloud.aiplatform.v1", manifest={"EnvVar",},) +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1', + manifest={ + 'EnvVar', + }, +) class EnvVar(proto.Message): diff --git a/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py index 63290ff9b4..e19c94b054 100644 --- a/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py @@ -27,7 +27,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"HyperparameterTuningJob",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'HyperparameterTuningJob', + }, ) @@ -106,7 +109,9 @@ class HyperparameterTuningJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) + study_spec = proto.Field(proto.MESSAGE, number=4, + message=study.StudySpec, + ) max_trial_count = proto.Field(proto.INT32, number=5) @@ -114,28 +119,42 @@ class HyperparameterTuningJob(proto.Message): max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field( - proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, + trial_job_spec = proto.Field(proto.MESSAGE, number=8, + message=custom_job.CustomJobSpec, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) + trials = proto.RepeatedField(proto.MESSAGE, number=9, + message=study.Trial, + ) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=10, + enum=job_state.JobState, + ) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) - error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=15, + message=status.Status, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=16) - encryption_spec = proto.Field( - proto.MESSAGE, number=17, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=17, + message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1/types/io.py b/google/cloud/aiplatform_v1/types/io.py index 1a75ea33bc..2cf3c7b5f6 100644 --- a/google/cloud/aiplatform_v1/types/io.py +++ b/google/cloud/aiplatform_v1/types/io.py @@ -19,13 +19,13 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "GcsSource", - "GcsDestination", - "BigQuerySource", - "BigQueryDestination", - "ContainerRegistryDestination", + 'GcsSource', + 'GcsDestination', + 'BigQuerySource', + 'BigQueryDestination', + 'ContainerRegistryDestination', }, ) diff --git a/google/cloud/aiplatform_v1/types/job_service.py b/google/cloud/aiplatform_v1/types/job_service.py index 3a6d844ea7..edf28bd54b 100644 --- a/google/cloud/aiplatform_v1/types/job_service.py +++ b/google/cloud/aiplatform_v1/types/job_service.py @@ -18,44 +18,40 @@ import proto # type: ignore -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "CreateCustomJobRequest", - "GetCustomJobRequest", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "DeleteCustomJobRequest", - "CancelCustomJobRequest", - "CreateDataLabelingJobRequest", - "GetDataLabelingJobRequest", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "DeleteDataLabelingJobRequest", - "CancelDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "GetHyperparameterTuningJobRequest", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "DeleteHyperparameterTuningJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "GetBatchPredictionJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "DeleteBatchPredictionJobRequest", - "CancelBatchPredictionJobRequest", + 'CreateCustomJobRequest', + 'GetCustomJobRequest', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'DeleteCustomJobRequest', + 'CancelCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'GetDataLabelingJobRequest', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'DeleteDataLabelingJobRequest', + 'CancelDataLabelingJobRequest', + 'CreateHyperparameterTuningJobRequest', + 'GetHyperparameterTuningJobRequest', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'DeleteHyperparameterTuningJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CreateBatchPredictionJobRequest', + 'GetBatchPredictionJobRequest', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'DeleteBatchPredictionJobRequest', + 'CancelBatchPredictionJobRequest', }, ) @@ -75,7 +71,9 @@ class CreateCustomJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) + custom_job = proto.Field(proto.MESSAGE, number=2, + message=gca_custom_job.CustomJob, + ) class GetCustomJobRequest(proto.Message): @@ -138,7 +136,9 @@ class ListCustomJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListCustomJobsResponse(proto.Message): @@ -158,8 +158,8 @@ class ListCustomJobsResponse(proto.Message): def raw_page(self): return self - custom_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, + custom_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_custom_job.CustomJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -206,8 +206,8 @@ class CreateDataLabelingJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field( - proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, + data_labeling_job = proto.Field(proto.MESSAGE, number=2, + message=gca_data_labeling_job.DataLabelingJob, ) @@ -273,7 +273,9 @@ class ListDataLabelingJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -294,8 +296,8 @@ class ListDataLabelingJobsResponse(proto.Message): def raw_page(self): return self - data_labeling_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, + data_labeling_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_data_labeling_job.DataLabelingJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -346,9 +348,7 @@ class CreateHyperparameterTuningJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field( - proto.MESSAGE, - number=2, + hyperparameter_tuning_job = proto.Field(proto.MESSAGE, number=2, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -415,7 +415,9 @@ class ListHyperparameterTuningJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListHyperparameterTuningJobsResponse(proto.Message): @@ -437,9 +439,7 @@ class ListHyperparameterTuningJobsResponse(proto.Message): def raw_page(self): return self - hyperparameter_tuning_jobs = proto.RepeatedField( - proto.MESSAGE, - number=1, + hyperparameter_tuning_jobs = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -491,8 +491,8 @@ class CreateBatchPredictionJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field( - proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_job = proto.Field(proto.MESSAGE, number=2, + message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -558,7 +558,9 @@ class ListBatchPredictionJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListBatchPredictionJobsResponse(proto.Message): @@ -579,8 +581,8 @@ class ListBatchPredictionJobsResponse(proto.Message): def raw_page(self): return self - batch_prediction_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_batch_prediction_job.BatchPredictionJob, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/job_state.py b/google/cloud/aiplatform_v1/types/job_state.py index 40b1694f86..5ca5147c2c 100644 --- a/google/cloud/aiplatform_v1/types/job_state.py +++ b/google/cloud/aiplatform_v1/types/job_state.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"JobState",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'JobState', + }, ) diff --git a/google/cloud/aiplatform_v1/types/machine_resources.py b/google/cloud/aiplatform_v1/types/machine_resources.py index f6864eb798..a5e8209b0f 100644 --- a/google/cloud/aiplatform_v1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1/types/machine_resources.py @@ -22,14 +22,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "MachineSpec", - "DedicatedResources", - "AutomaticResources", - "BatchDedicatedResources", - "ResourcesConsumed", - "DiskSpec", + 'MachineSpec', + 'DedicatedResources', + 'AutomaticResources', + 'BatchDedicatedResources', + 'ResourcesConsumed', + 'DiskSpec', }, ) @@ -64,8 +64,8 @@ class MachineSpec(proto.Message): machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field( - proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, + accelerator_type = proto.Field(proto.ENUM, number=2, + enum=gca_accelerator_type.AcceleratorType, ) accelerator_count = proto.Field(proto.INT32, number=3) @@ -104,7 +104,9 @@ class DedicatedResources(proto.Message): as the default value. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + machine_spec = proto.Field(proto.MESSAGE, number=1, + message='MachineSpec', + ) min_replica_count = proto.Field(proto.INT32, number=2) @@ -168,7 +170,9 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + machine_spec = proto.Field(proto.MESSAGE, number=1, + message='MachineSpec', + ) starting_replica_count = proto.Field(proto.INT32, number=2) diff --git a/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py index 7500d618a0..07abcc8f01 100644 --- a/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py +++ b/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"ManualBatchTuningParameters",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'ManualBatchTuningParameters', + }, ) diff --git a/google/cloud/aiplatform_v1/types/migratable_resource.py b/google/cloud/aiplatform_v1/types/migratable_resource.py index 652a835c89..0b73b10a22 100644 --- a/google/cloud/aiplatform_v1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1/types/migratable_resource.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"MigratableResource",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'MigratableResource', + }, ) @@ -52,7 +55,6 @@ class MigratableResource(proto.Message): Output only. Timestamp when this MigratableResource was last updated. """ - class MlEngineModelVersion(proto.Message): r"""Represents one model Version in ml.googleapis.com. @@ -121,7 +123,6 @@ class DataLabelingDataset(proto.Message): datalabeling.googleapis.com belongs to the data labeling Dataset. """ - class DataLabelingAnnotatedDataset(proto.Message): r"""Represents one AnnotatedDataset in datalabeling.googleapis.com. @@ -145,34 +146,32 @@ class DataLabelingAnnotatedDataset(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=4) - data_labeling_annotated_datasets = proto.RepeatedField( - proto.MESSAGE, - number=3, - message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", + data_labeling_annotated_datasets = proto.RepeatedField(proto.MESSAGE, number=3, + message='MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset', ) - ml_engine_model_version = proto.Field( - proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, + ml_engine_model_version = proto.Field(proto.MESSAGE, number=1, oneof='resource', + message=MlEngineModelVersion, ) - automl_model = proto.Field( - proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, + automl_model = proto.Field(proto.MESSAGE, number=2, oneof='resource', + message=AutomlModel, ) - automl_dataset = proto.Field( - proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, + automl_dataset = proto.Field(proto.MESSAGE, number=3, oneof='resource', + message=AutomlDataset, ) - data_labeling_dataset = proto.Field( - proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, + data_labeling_dataset = proto.Field(proto.MESSAGE, number=4, oneof='resource', + message=DataLabelingDataset, ) - last_migrate_time = proto.Field( - proto.MESSAGE, number=5, message=timestamp.Timestamp, + last_migrate_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, ) - last_update_time = proto.Field( - proto.MESSAGE, number=6, message=timestamp.Timestamp, + last_update_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, ) diff --git a/google/cloud/aiplatform_v1/types/migration_service.py b/google/cloud/aiplatform_v1/types/migration_service.py index acd69b37b4..d608620577 100644 --- a/google/cloud/aiplatform_v1/types/migration_service.py +++ b/google/cloud/aiplatform_v1/types/migration_service.py @@ -18,23 +18,21 @@ import proto # type: ignore -from google.cloud.aiplatform_v1.types import ( - migratable_resource as gca_migratable_resource, -) +from google.cloud.aiplatform_v1.types import migratable_resource as gca_migratable_resource from google.cloud.aiplatform_v1.types import operation from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", - "BatchMigrateResourcesRequest", - "MigrateResourceRequest", - "BatchMigrateResourcesResponse", - "MigrateResourceResponse", - "BatchMigrateResourcesOperationMetadata", + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'BatchMigrateResourcesRequest', + 'MigrateResourceRequest', + 'BatchMigrateResourcesResponse', + 'MigrateResourceResponse', + 'BatchMigrateResourcesOperationMetadata', }, ) @@ -101,8 +99,8 @@ class SearchMigratableResourcesResponse(proto.Message): def raw_page(self): return self - migratable_resources = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource, + migratable_resources = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_migratable_resource.MigratableResource, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -125,8 +123,8 @@ class BatchMigrateResourcesRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - migrate_resource_requests = proto.RepeatedField( - proto.MESSAGE, number=2, message="MigrateResourceRequest", + migrate_resource_requests = proto.RepeatedField(proto.MESSAGE, number=2, + message='MigrateResourceRequest', ) @@ -150,7 +148,6 @@ class MigrateResourceRequest(proto.Message): datalabeling.googleapis.com to AI Platform's Dataset. """ - class MigrateMlEngineModelVersionConfig(proto.Message): r"""Config for migrating version in ml.googleapis.com to AI Platform's Model. @@ -238,7 +235,6 @@ class MigrateDataLabelingDatasetConfig(proto.Message): AnnotatedDatasets have to belong to the datalabeling Dataset. """ - class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): r"""Config for migrating AnnotatedDataset in datalabeling.googleapis.com to AI Platform's SavedQuery. @@ -257,31 +253,23 @@ class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=2) - migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( - proto.MESSAGE, - number=3, - message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField(proto.MESSAGE, number=3, + message='MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig', ) - migrate_ml_engine_model_version_config = proto.Field( - proto.MESSAGE, - number=1, - oneof="request", + migrate_ml_engine_model_version_config = proto.Field(proto.MESSAGE, number=1, oneof='request', message=MigrateMlEngineModelVersionConfig, ) - migrate_automl_model_config = proto.Field( - proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, + migrate_automl_model_config = proto.Field(proto.MESSAGE, number=2, oneof='request', + message=MigrateAutomlModelConfig, ) - migrate_automl_dataset_config = proto.Field( - proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, + migrate_automl_dataset_config = proto.Field(proto.MESSAGE, number=3, oneof='request', + message=MigrateAutomlDatasetConfig, ) - migrate_data_labeling_dataset_config = proto.Field( - proto.MESSAGE, - number=4, - oneof="request", + migrate_data_labeling_dataset_config = proto.Field(proto.MESSAGE, number=4, oneof='request', message=MigrateDataLabelingDatasetConfig, ) @@ -295,8 +283,8 @@ class BatchMigrateResourcesResponse(proto.Message): Successfully migrated resources. """ - migrate_resource_responses = proto.RepeatedField( - proto.MESSAGE, number=1, message="MigrateResourceResponse", + migrate_resource_responses = proto.RepeatedField(proto.MESSAGE, number=1, + message='MigrateResourceResponse', ) @@ -314,12 +302,12 @@ class MigrateResourceResponse(proto.Message): datalabeling.googleapis.com. """ - dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") + dataset = proto.Field(proto.STRING, number=1, oneof='migrated_resource') - model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") + model = proto.Field(proto.STRING, number=2, oneof='migrated_resource') - migratable_resource = proto.Field( - proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, + migratable_resource = proto.Field(proto.MESSAGE, number=3, + message=gca_migratable_resource.MigratableResource, ) @@ -334,7 +322,6 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): Partial results that reflect the latest migration operation progress. """ - class PartialResult(proto.Message): r"""Represents a partial result in batch migration operation for one ``MigrateResourceRequest``. @@ -352,24 +339,24 @@ class PartialResult(proto.Message): [MigrateResourceRequest.migrate_resource_requests][]. """ - error = proto.Field( - proto.MESSAGE, number=2, oneof="result", message=status.Status, + error = proto.Field(proto.MESSAGE, number=2, oneof='result', + message=status.Status, ) - model = proto.Field(proto.STRING, number=3, oneof="result") + model = proto.Field(proto.STRING, number=3, oneof='result') - dataset = proto.Field(proto.STRING, number=4, oneof="result") + dataset = proto.Field(proto.STRING, number=4, oneof='result') - request = proto.Field( - proto.MESSAGE, number=1, message="MigrateResourceRequest", + request = proto.Field(proto.MESSAGE, number=1, + message='MigrateResourceRequest', ) - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) - partial_results = proto.RepeatedField( - proto.MESSAGE, number=2, message=PartialResult, + partial_results = proto.RepeatedField(proto.MESSAGE, number=2, + message=PartialResult, ) diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py index c2db797b98..b830ba86da 100644 --- a/google/cloud/aiplatform_v1/types/model.py +++ b/google/cloud/aiplatform_v1/types/model.py @@ -26,8 +26,13 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", - manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'Model', + 'PredictSchemata', + 'ModelContainerSpec', + 'Port', + }, ) @@ -213,7 +218,6 @@ class Model(proto.Message): Model. If set, this Model and all sub-resources of this Model will be secured by this key. """ - class DeploymentResourcesType(proto.Enum): r"""Identifies a type of Model's prediction resources.""" DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 @@ -250,7 +254,6 @@ class ExportFormat(proto.Message): Output only. The content of this Model that may be exported. """ - class ExportableContent(proto.Enum): r"""The Model content that can be exported.""" EXPORTABLE_CONTENT_UNSPECIFIED = 0 @@ -259,8 +262,8 @@ class ExportableContent(proto.Enum): id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField( - proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", + exportable_contents = proto.RepeatedField(proto.ENUM, number=2, + enum='Model.ExportFormat.ExportableContent', ) name = proto.Field(proto.STRING, number=1) @@ -269,44 +272,54 @@ class ExportableContent(proto.Enum): description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) + predict_schemata = proto.Field(proto.MESSAGE, number=4, + message='PredictSchemata', + ) metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) - supported_export_formats = proto.RepeatedField( - proto.MESSAGE, number=20, message=ExportFormat, + supported_export_formats = proto.RepeatedField(proto.MESSAGE, number=20, + message=ExportFormat, ) training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) + container_spec = proto.Field(proto.MESSAGE, number=9, + message='ModelContainerSpec', + ) artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField( - proto.ENUM, number=10, enum=DeploymentResourcesType, + supported_deployment_resources_types = proto.RepeatedField(proto.ENUM, number=10, + enum=DeploymentResourcesType, ) supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) - deployed_models = proto.RepeatedField( - proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, + deployed_models = proto.RepeatedField(proto.MESSAGE, number=15, + message=deployed_model_ref.DeployedModelRef, ) etag = proto.Field(proto.STRING, number=16) labels = proto.MapField(proto.STRING, proto.STRING, number=17) - encryption_spec = proto.Field( - proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=24, + message=gca_encryption_spec.EncryptionSpec, ) @@ -605,9 +618,13 @@ class ModelContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) + env = proto.RepeatedField(proto.MESSAGE, number=4, + message=env_var.EnvVar, + ) - ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) + ports = proto.RepeatedField(proto.MESSAGE, number=5, + message='Port', + ) predict_route = proto.Field(proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1/types/model_evaluation.py b/google/cloud/aiplatform_v1/types/model_evaluation.py index f617f3d197..08bafad024 100644 --- a/google/cloud/aiplatform_v1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1/types/model_evaluation.py @@ -23,7 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"ModelEvaluation",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'ModelEvaluation', + }, ) @@ -63,9 +66,13 @@ class ModelEvaluation(proto.Message): metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + metrics = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) slice_dimensions = proto.RepeatedField(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1/types/model_evaluation_slice.py index 5653c3d2b6..2b6065593c 100644 --- a/google/cloud/aiplatform_v1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1/types/model_evaluation_slice.py @@ -23,7 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"ModelEvaluationSlice",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'ModelEvaluationSlice', + }, ) @@ -54,7 +57,6 @@ class ModelEvaluationSlice(proto.Message): Output only. Timestamp when this ModelEvaluationSlice was created. """ - class Slice(proto.Message): r"""Definition of a slice. @@ -79,13 +81,19 @@ class Slice(proto.Message): name = proto.Field(proto.STRING, number=1) - slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) + slice_ = proto.Field(proto.MESSAGE, number=2, + message=Slice, + ) metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + metrics = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/model_service.py b/google/cloud/aiplatform_v1/types/model_service.py index 454e014fd5..e3053327c4 100644 --- a/google/cloud/aiplatform_v1/types/model_service.py +++ b/google/cloud/aiplatform_v1/types/model_service.py @@ -27,25 +27,25 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "UploadModelRequest", - "UploadModelOperationMetadata", - "UploadModelResponse", - "GetModelRequest", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", - "DeleteModelRequest", - "ExportModelRequest", - "ExportModelOperationMetadata", - "ExportModelResponse", - "GetModelEvaluationRequest", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "GetModelEvaluationSliceRequest", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", + 'UploadModelRequest', + 'UploadModelOperationMetadata', + 'UploadModelResponse', + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'UpdateModelRequest', + 'DeleteModelRequest', + 'ExportModelRequest', + 'ExportModelOperationMetadata', + 'ExportModelResponse', + 'GetModelEvaluationRequest', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'GetModelEvaluationSliceRequest', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', }, ) @@ -65,7 +65,9 @@ class UploadModelRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) + model = proto.Field(proto.MESSAGE, number=2, + message=gca_model.Model, + ) class UploadModelOperationMetadata(proto.Message): @@ -78,8 +80,8 @@ class UploadModelOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -169,7 +171,9 @@ class ListModelsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -191,7 +195,9 @@ class ListModelsResponse(proto.Message): def raw_page(self): return self - models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) + models = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_model.Model, + ) next_page_token = proto.Field(proto.STRING, number=2) @@ -210,9 +216,13 @@ class UpdateModelRequest(proto.Message): `FieldMask `__. """ - model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) + model = proto.Field(proto.MESSAGE, number=1, + message=gca_model.Model, + ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class DeleteModelRequest(proto.Message): @@ -241,7 +251,6 @@ class ExportModelRequest(proto.Message): Required. The desired output location and configuration. """ - class OutputConfig(proto.Message): r"""Output configuration for the Model export. @@ -273,17 +282,19 @@ class OutputConfig(proto.Message): export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field( - proto.MESSAGE, number=3, message=io.GcsDestination, + artifact_destination = proto.Field(proto.MESSAGE, number=3, + message=io.GcsDestination, ) - image_destination = proto.Field( - proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, + image_destination = proto.Field(proto.MESSAGE, number=4, + message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) + output_config = proto.Field(proto.MESSAGE, number=2, + message=OutputConfig, + ) class ExportModelOperationMetadata(proto.Message): @@ -298,7 +309,6 @@ class ExportModelOperationMetadata(proto.Message): Output only. Information further describing the output of this Model export. """ - class OutputInfo(proto.Message): r"""Further describes the output of the ExportModel. Supplements ``ExportModelRequest.OutputConfig``. @@ -320,11 +330,13 @@ class OutputInfo(proto.Message): image_output_uri = proto.Field(proto.STRING, number=3) - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) - output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) + output_info = proto.Field(proto.MESSAGE, number=2, + message=OutputInfo, + ) class ExportModelResponse(proto.Message): @@ -379,7 +391,9 @@ class ListModelEvaluationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelEvaluationsResponse(proto.Message): @@ -400,8 +414,8 @@ class ListModelEvaluationsResponse(proto.Message): def raw_page(self): return self - model_evaluations = proto.RepeatedField( - proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, + model_evaluations = proto.RepeatedField(proto.MESSAGE, number=1, + message=model_evaluation.ModelEvaluation, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -456,7 +470,9 @@ class ListModelEvaluationSlicesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelEvaluationSlicesResponse(proto.Message): @@ -477,8 +493,8 @@ class ListModelEvaluationSlicesResponse(proto.Message): def raw_page(self): return self - model_evaluation_slices = proto.RepeatedField( - proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, + model_evaluation_slices = proto.RepeatedField(proto.MESSAGE, number=1, + message=model_evaluation_slice.ModelEvaluationSlice, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/operation.py b/google/cloud/aiplatform_v1/types/operation.py index fe24030e79..2f8211a6ad 100644 --- a/google/cloud/aiplatform_v1/types/operation.py +++ b/google/cloud/aiplatform_v1/types/operation.py @@ -23,8 +23,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", - manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'GenericOperationMetadata', + 'DeleteOperationMetadata', + }, ) @@ -48,13 +51,17 @@ class GenericOperationMetadata(proto.Message): finish time. """ - partial_failures = proto.RepeatedField( - proto.MESSAGE, number=1, message=status.Status, + partial_failures = proto.RepeatedField(proto.MESSAGE, number=1, + message=status.Status, ) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=2, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) class DeleteOperationMetadata(proto.Message): @@ -65,8 +72,8 @@ class DeleteOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message="GenericOperationMetadata", + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message='GenericOperationMetadata', ) diff --git a/google/cloud/aiplatform_v1/types/pipeline_service.py b/google/cloud/aiplatform_v1/types/pipeline_service.py index b2c6d5bbe3..e757607527 100644 --- a/google/cloud/aiplatform_v1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1/types/pipeline_service.py @@ -23,14 +23,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "CreateTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", + 'CreateTrainingPipelineRequest', + 'GetTrainingPipelineRequest', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'DeleteTrainingPipelineRequest', + 'CancelTrainingPipelineRequest', }, ) @@ -50,8 +50,8 @@ class CreateTrainingPipelineRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field( - proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, + training_pipeline = proto.Field(proto.MESSAGE, number=2, + message=gca_training_pipeline.TrainingPipeline, ) @@ -114,7 +114,9 @@ class ListTrainingPipelinesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListTrainingPipelinesResponse(proto.Message): @@ -135,8 +137,8 @@ class ListTrainingPipelinesResponse(proto.Message): def raw_page(self): return self - training_pipelines = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, + training_pipelines = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_training_pipeline.TrainingPipeline, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/pipeline_state.py b/google/cloud/aiplatform_v1/types/pipeline_state.py index f6a885ae42..6a00f05fef 100644 --- a/google/cloud/aiplatform_v1/types/pipeline_state.py +++ b/google/cloud/aiplatform_v1/types/pipeline_state.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"PipelineState",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'PipelineState', + }, ) diff --git a/google/cloud/aiplatform_v1/types/prediction_service.py b/google/cloud/aiplatform_v1/types/prediction_service.py index 21a01372f4..c7d39c373b 100644 --- a/google/cloud/aiplatform_v1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1/types/prediction_service.py @@ -22,8 +22,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", - manifest={"PredictRequest", "PredictResponse",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'PredictRequest', + 'PredictResponse', + }, ) @@ -58,9 +61,13 @@ class PredictRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) + instances = proto.RepeatedField(proto.MESSAGE, number=2, + message=struct.Value, + ) - parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + parameters = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) class PredictResponse(proto.Message): @@ -80,7 +87,9 @@ class PredictResponse(proto.Message): served this prediction. """ - predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) + predictions = proto.RepeatedField(proto.MESSAGE, number=1, + message=struct.Value, + ) deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/specialist_pool.py b/google/cloud/aiplatform_v1/types/specialist_pool.py index 6265316bd5..b57aa89666 100644 --- a/google/cloud/aiplatform_v1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1/types/specialist_pool.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"SpecialistPool",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'SpecialistPool', + }, ) diff --git a/google/cloud/aiplatform_v1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1/types/specialist_pool_service.py index 69e49bb355..669756640f 100644 --- a/google/cloud/aiplatform_v1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1/types/specialist_pool_service.py @@ -24,16 +24,16 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "CreateSpecialistPoolRequest", - "CreateSpecialistPoolOperationMetadata", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", - "UpdateSpecialistPoolOperationMetadata", + 'CreateSpecialistPoolRequest', + 'CreateSpecialistPoolOperationMetadata', + 'GetSpecialistPoolRequest', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'DeleteSpecialistPoolRequest', + 'UpdateSpecialistPoolRequest', + 'UpdateSpecialistPoolOperationMetadata', }, ) @@ -53,8 +53,8 @@ class CreateSpecialistPoolRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field( - proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field(proto.MESSAGE, number=2, + message=gca_specialist_pool.SpecialistPool, ) @@ -67,8 +67,8 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -114,7 +114,9 @@ class ListSpecialistPoolsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=4, + message=field_mask.FieldMask, + ) class ListSpecialistPoolsResponse(proto.Message): @@ -133,8 +135,8 @@ class ListSpecialistPoolsResponse(proto.Message): def raw_page(self): return self - specialist_pools = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + specialist_pools = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_specialist_pool.SpecialistPool, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -174,11 +176,13 @@ class UpdateSpecialistPoolRequest(proto.Message): resource. """ - specialist_pool = proto.Field( - proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field(proto.MESSAGE, number=1, + message=gca_specialist_pool.SpecialistPool, ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class UpdateSpecialistPoolOperationMetadata(proto.Message): @@ -197,8 +201,8 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field( - proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=2, + message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1/types/study.py b/google/cloud/aiplatform_v1/types/study.py index 99a688f045..0254866d5b 100644 --- a/google/cloud/aiplatform_v1/types/study.py +++ b/google/cloud/aiplatform_v1/types/study.py @@ -23,8 +23,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", - manifest={"Trial", "StudySpec", "Measurement",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'Trial', + 'StudySpec', + 'Measurement', + }, ) @@ -54,7 +58,6 @@ class Trial(proto.Message): Trial. It's set for a HyperparameterTuningJob's Trial. """ - class State(proto.Enum): r"""Describes a Trial state.""" STATE_UNSPECIFIED = 0 @@ -82,19 +85,31 @@ class Parameter(proto.Message): parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) + value = proto.Field(proto.MESSAGE, number=2, + message=struct.Value, + ) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, enum=State,) + state = proto.Field(proto.ENUM, number=3, + enum=State, + ) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) + parameters = proto.RepeatedField(proto.MESSAGE, number=4, + message=Parameter, + ) - final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) + final_measurement = proto.Field(proto.MESSAGE, number=5, + message='Measurement', + ) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) custom_job = proto.Field(proto.STRING, number=11) @@ -118,7 +133,6 @@ class StudySpec(proto.Message): Describe which measurement selection type will be used """ - class Algorithm(proto.Enum): r"""The available search algorithms for the Study.""" ALGORITHM_UNSPECIFIED = 0 @@ -164,7 +178,6 @@ class MetricSpec(proto.Message): Required. The optimization goal of the metric. """ - class GoalType(proto.Enum): r"""The available types of optimization goals.""" GOAL_TYPE_UNSPECIFIED = 0 @@ -173,7 +186,9 @@ class GoalType(proto.Enum): metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) + goal = proto.Field(proto.ENUM, number=2, + enum='StudySpec.MetricSpec.GoalType', + ) class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. @@ -201,7 +216,6 @@ class ParameterSpec(proto.Message): If two items in conditional_parameter_specs have the same name, they must have disjoint parent_value_condition. """ - class ScaleType(proto.Enum): r"""The type of scaling that should be applied to this parameter.""" SCALE_TYPE_UNSPECIFIED = 0 @@ -284,7 +298,6 @@ class ConditionalParameterSpec(proto.Message): Required. The spec for a conditional parameter. """ - class DiscreteValueCondition(proto.Message): r"""Represents the spec to match discrete values from parent parameter. @@ -326,81 +339,66 @@ class CategoricalValueCondition(proto.Message): values = proto.RepeatedField(proto.STRING, number=1) - parent_discrete_values = proto.Field( - proto.MESSAGE, - number=2, - oneof="parent_value_condition", - message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", + parent_discrete_values = proto.Field(proto.MESSAGE, number=2, oneof='parent_value_condition', + message='StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition', ) - parent_int_values = proto.Field( - proto.MESSAGE, - number=3, - oneof="parent_value_condition", - message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", + parent_int_values = proto.Field(proto.MESSAGE, number=3, oneof='parent_value_condition', + message='StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition', ) - parent_categorical_values = proto.Field( - proto.MESSAGE, - number=4, - oneof="parent_value_condition", - message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", + parent_categorical_values = proto.Field(proto.MESSAGE, number=4, oneof='parent_value_condition', + message='StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition', ) - parameter_spec = proto.Field( - proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", + parameter_spec = proto.Field(proto.MESSAGE, number=1, + message='StudySpec.ParameterSpec', ) - double_value_spec = proto.Field( - proto.MESSAGE, - number=2, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.DoubleValueSpec", + double_value_spec = proto.Field(proto.MESSAGE, number=2, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.DoubleValueSpec', ) - integer_value_spec = proto.Field( - proto.MESSAGE, - number=3, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.IntegerValueSpec", + integer_value_spec = proto.Field(proto.MESSAGE, number=3, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.IntegerValueSpec', ) - categorical_value_spec = proto.Field( - proto.MESSAGE, - number=4, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.CategoricalValueSpec", + categorical_value_spec = proto.Field(proto.MESSAGE, number=4, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.CategoricalValueSpec', ) - discrete_value_spec = proto.Field( - proto.MESSAGE, - number=5, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.DiscreteValueSpec", + discrete_value_spec = proto.Field(proto.MESSAGE, number=5, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.DiscreteValueSpec', ) parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field( - proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", + scale_type = proto.Field(proto.ENUM, number=6, + enum='StudySpec.ParameterSpec.ScaleType', ) - conditional_parameter_specs = proto.RepeatedField( - proto.MESSAGE, - number=10, - message="StudySpec.ParameterSpec.ConditionalParameterSpec", + conditional_parameter_specs = proto.RepeatedField(proto.MESSAGE, number=10, + message='StudySpec.ParameterSpec.ConditionalParameterSpec', ) - metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, + message=MetricSpec, + ) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) + parameters = proto.RepeatedField(proto.MESSAGE, number=2, + message=ParameterSpec, + ) - algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) + algorithm = proto.Field(proto.ENUM, number=3, + enum=Algorithm, + ) - observation_noise = proto.Field(proto.ENUM, number=6, enum=ObservationNoise,) + observation_noise = proto.Field(proto.ENUM, number=6, + enum=ObservationNoise, + ) - measurement_selection_type = proto.Field( - proto.ENUM, number=7, enum=MeasurementSelectionType, + measurement_selection_type = proto.Field(proto.ENUM, number=7, + enum=MeasurementSelectionType, ) @@ -419,7 +417,6 @@ class Measurement(proto.Message): evaluating the objective functions using suggested Parameter values. """ - class Metric(proto.Message): r"""A message representing a metric in the measurement. @@ -438,7 +435,9 @@ class Metric(proto.Message): step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) + metrics = proto.RepeatedField(proto.MESSAGE, number=3, + message=Metric, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/training_pipeline.py b/google/cloud/aiplatform_v1/types/training_pipeline.py index 9a41f231a5..b0135a926b 100644 --- a/google/cloud/aiplatform_v1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1/types/training_pipeline.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", + package='google.cloud.aiplatform.v1', manifest={ - "TrainingPipeline", - "InputDataConfig", - "FractionSplit", - "FilterSplit", - "PredefinedSplit", - "TimestampSplit", + 'TrainingPipeline', + 'InputDataConfig', + 'FractionSplit', + 'FilterSplit', + 'PredefinedSplit', + 'TimestampSplit', }, ) @@ -154,32 +154,52 @@ class TrainingPipeline(proto.Message): display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) + input_data_config = proto.Field(proto.MESSAGE, number=3, + message='InputDataConfig', + ) training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + training_task_inputs = proto.Field(proto.MESSAGE, number=5, + message=struct.Value, + ) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + training_task_metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) - model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) + model_to_upload = proto.Field(proto.MESSAGE, number=7, + message=model.Model, + ) - state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) + state = proto.Field(proto.ENUM, number=9, + enum=pipeline_state.PipelineState, + ) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=10, + message=status.Status, + ) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=15) - encryption_spec = proto.Field( - proto.MESSAGE, number=18, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=18, + message=gca_encryption_spec.EncryptionSpec, ) @@ -303,28 +323,28 @@ class InputDataConfig(proto.Message): ``annotation_schema_uri``. """ - fraction_split = proto.Field( - proto.MESSAGE, number=2, oneof="split", message="FractionSplit", + fraction_split = proto.Field(proto.MESSAGE, number=2, oneof='split', + message='FractionSplit', ) - filter_split = proto.Field( - proto.MESSAGE, number=3, oneof="split", message="FilterSplit", + filter_split = proto.Field(proto.MESSAGE, number=3, oneof='split', + message='FilterSplit', ) - predefined_split = proto.Field( - proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", + predefined_split = proto.Field(proto.MESSAGE, number=4, oneof='split', + message='PredefinedSplit', ) - timestamp_split = proto.Field( - proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", + timestamp_split = proto.Field(proto.MESSAGE, number=5, oneof='split', + message='TimestampSplit', ) - gcs_destination = proto.Field( - proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=8, oneof='destination', + message=io.GcsDestination, ) - bigquery_destination = proto.Field( - proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, + bigquery_destination = proto.Field(proto.MESSAGE, number=10, oneof='destination', + message=io.BigQueryDestination, ) dataset_id = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1/types/user_action_reference.py b/google/cloud/aiplatform_v1/types/user_action_reference.py index da59ac6ac6..89d799178a 100644 --- a/google/cloud/aiplatform_v1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1/types/user_action_reference.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1", manifest={"UserActionReference",}, + package='google.cloud.aiplatform.v1', + manifest={ + 'UserActionReference', + }, ) @@ -44,9 +47,9 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". """ - operation = proto.Field(proto.STRING, number=1, oneof="reference") + operation = proto.Field(proto.STRING, number=1, oneof='reference') - data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") + data_labeling_job = proto.Field(proto.STRING, number=2, oneof='reference') method = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 621f1e96f8..0dbcbec2d6 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -18,6 +18,7 @@ from .services.dataset_service import DatasetServiceClient from .services.endpoint_service import EndpointServiceClient from .services.job_service import JobServiceClient +from .services.metadata_service import MetadataServiceClient from .services.migration_service import MigrationServiceClient from .services.model_service import ModelServiceClient from .services.pipeline_service import PipelineServiceClient @@ -27,8 +28,10 @@ from .types.accelerator_type import AcceleratorType from .types.annotation import Annotation from .types.annotation_spec import AnnotationSpec +from .types.artifact import Artifact from .types.batch_prediction_job import BatchPredictionJob from .types.completion_stats import CompletionStats +from .types.context import Context from .types.custom_job import ContainerSpec from .types.custom_job import CustomJob from .types.custom_job import CustomJobSpec @@ -79,6 +82,8 @@ from .types.endpoint_service import UndeployModelResponse from .types.endpoint_service import UpdateEndpointRequest from .types.env_var import EnvVar +from .types.event import Event +from .types.execution import Execution from .types.explanation import Attribution from .types.explanation import Explanation from .types.explanation import ExplanationMetadataOverride @@ -92,6 +97,7 @@ from .types.explanation import SmoothGradConfig from .types.explanation import XraiAttribution from .types.explanation_metadata import ExplanationMetadata +from .types.feature_monitoring_stats import FeatureStatsAnomaly from .types.hyperparameter_tuning_job import HyperparameterTuningJob from .types.io import BigQueryDestination from .types.io import BigQuerySource @@ -106,14 +112,17 @@ from .types.job_service import CreateCustomJobRequest from .types.job_service import CreateDataLabelingJobRequest from .types.job_service import CreateHyperparameterTuningJobRequest +from .types.job_service import CreateModelDeploymentMonitoringJobRequest from .types.job_service import DeleteBatchPredictionJobRequest from .types.job_service import DeleteCustomJobRequest from .types.job_service import DeleteDataLabelingJobRequest from .types.job_service import DeleteHyperparameterTuningJobRequest +from .types.job_service import DeleteModelDeploymentMonitoringJobRequest from .types.job_service import GetBatchPredictionJobRequest from .types.job_service import GetCustomJobRequest from .types.job_service import GetDataLabelingJobRequest from .types.job_service import GetHyperparameterTuningJobRequest +from .types.job_service import GetModelDeploymentMonitoringJobRequest from .types.job_service import ListBatchPredictionJobsRequest from .types.job_service import ListBatchPredictionJobsResponse from .types.job_service import ListCustomJobsRequest @@ -122,7 +131,16 @@ from .types.job_service import ListDataLabelingJobsResponse from .types.job_service import ListHyperparameterTuningJobsRequest from .types.job_service import ListHyperparameterTuningJobsResponse +from .types.job_service import ListModelDeploymentMonitoringJobsRequest +from .types.job_service import ListModelDeploymentMonitoringJobsResponse +from .types.job_service import PauseModelDeploymentMonitoringJobRequest +from .types.job_service import ResumeModelDeploymentMonitoringJobRequest +from .types.job_service import SearchModelDeploymentMonitoringStatsAnomaliesRequest +from .types.job_service import SearchModelDeploymentMonitoringStatsAnomaliesResponse +from .types.job_service import UpdateModelDeploymentMonitoringJobOperationMetadata +from .types.job_service import UpdateModelDeploymentMonitoringJobRequest from .types.job_state import JobState +from .types.lineage_subgraph import LineageSubgraph from .types.machine_resources import AutomaticResources from .types.machine_resources import AutoscalingMetricSpec from .types.machine_resources import BatchDedicatedResources @@ -131,6 +149,43 @@ from .types.machine_resources import MachineSpec from .types.machine_resources import ResourcesConsumed from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters +from .types.metadata_schema import MetadataSchema +from .types.metadata_service import AddContextArtifactsAndExecutionsRequest +from .types.metadata_service import AddContextArtifactsAndExecutionsResponse +from .types.metadata_service import AddContextChildrenRequest +from .types.metadata_service import AddContextChildrenResponse +from .types.metadata_service import AddExecutionEventsRequest +from .types.metadata_service import AddExecutionEventsResponse +from .types.metadata_service import CreateArtifactRequest +from .types.metadata_service import CreateContextRequest +from .types.metadata_service import CreateExecutionRequest +from .types.metadata_service import CreateMetadataSchemaRequest +from .types.metadata_service import CreateMetadataStoreOperationMetadata +from .types.metadata_service import CreateMetadataStoreRequest +from .types.metadata_service import DeleteContextRequest +from .types.metadata_service import DeleteMetadataStoreOperationMetadata +from .types.metadata_service import DeleteMetadataStoreRequest +from .types.metadata_service import GetArtifactRequest +from .types.metadata_service import GetContextRequest +from .types.metadata_service import GetExecutionRequest +from .types.metadata_service import GetMetadataSchemaRequest +from .types.metadata_service import GetMetadataStoreRequest +from .types.metadata_service import ListArtifactsRequest +from .types.metadata_service import ListArtifactsResponse +from .types.metadata_service import ListContextsRequest +from .types.metadata_service import ListContextsResponse +from .types.metadata_service import ListExecutionsRequest +from .types.metadata_service import ListExecutionsResponse +from .types.metadata_service import ListMetadataSchemasRequest +from .types.metadata_service import ListMetadataSchemasResponse +from .types.metadata_service import ListMetadataStoresRequest +from .types.metadata_service import ListMetadataStoresResponse +from .types.metadata_service import QueryContextLineageSubgraphRequest +from .types.metadata_service import QueryExecutionInputsAndOutputsRequest +from .types.metadata_service import UpdateArtifactRequest +from .types.metadata_service import UpdateContextRequest +from .types.metadata_service import UpdateExecutionRequest +from .types.metadata_store import MetadataStore from .types.migratable_resource import MigratableResource from .types.migration_service import BatchMigrateResourcesOperationMetadata from .types.migration_service import BatchMigrateResourcesRequest @@ -143,8 +198,18 @@ from .types.model import ModelContainerSpec from .types.model import Port from .types.model import PredictSchemata +from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringBigQueryTable +from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringJob +from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringObjectiveConfig +from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringObjectiveType +from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringScheduleConfig +from .types.model_deployment_monitoring_job import ModelMonitoringStatsAnomalies from .types.model_evaluation import ModelEvaluation from .types.model_evaluation_slice import ModelEvaluationSlice +from .types.model_monitoring import ModelMonitoringAlertConfig +from .types.model_monitoring import ModelMonitoringObjectiveConfig +from .types.model_monitoring import SamplingStrategy +from .types.model_monitoring import ThresholdConfig from .types.model_service import DeleteModelRequest from .types.model_service import ExportModelOperationMetadata from .types.model_service import ExportModelRequest @@ -220,206 +285,271 @@ __all__ = ( - "AcceleratorType", - "ActiveLearningConfig", - "AddTrialMeasurementRequest", - "Annotation", - "AnnotationSpec", - "Attribution", - "AutomaticResources", - "AutoscalingMetricSpec", - "BatchDedicatedResources", - "BatchMigrateResourcesOperationMetadata", - "BatchMigrateResourcesRequest", - "BatchMigrateResourcesResponse", - "BatchPredictionJob", - "BigQueryDestination", - "BigQuerySource", - "CancelBatchPredictionJobRequest", - "CancelCustomJobRequest", - "CancelDataLabelingJobRequest", - "CancelHyperparameterTuningJobRequest", - "CancelTrainingPipelineRequest", - "CheckTrialEarlyStoppingStateMetatdata", - "CheckTrialEarlyStoppingStateRequest", - "CheckTrialEarlyStoppingStateResponse", - "CompleteTrialRequest", - "CompletionStats", - "ContainerRegistryDestination", - "ContainerSpec", - "CreateBatchPredictionJobRequest", - "CreateCustomJobRequest", - "CreateDataLabelingJobRequest", - "CreateDatasetOperationMetadata", - "CreateDatasetRequest", - "CreateEndpointOperationMetadata", - "CreateEndpointRequest", - "CreateHyperparameterTuningJobRequest", - "CreateSpecialistPoolOperationMetadata", - "CreateSpecialistPoolRequest", - "CreateStudyRequest", - "CreateTrainingPipelineRequest", - "CreateTrialRequest", - "CustomJob", - "CustomJobSpec", - "DataItem", - "DataLabelingJob", - "Dataset", - "DatasetServiceClient", - "DedicatedResources", - "DeleteBatchPredictionJobRequest", - "DeleteCustomJobRequest", - "DeleteDataLabelingJobRequest", - "DeleteDatasetRequest", - "DeleteEndpointRequest", - "DeleteHyperparameterTuningJobRequest", - "DeleteModelRequest", - "DeleteOperationMetadata", - "DeleteSpecialistPoolRequest", - "DeleteStudyRequest", - "DeleteTrainingPipelineRequest", - "DeleteTrialRequest", - "DeployModelOperationMetadata", - "DeployModelRequest", - "DeployModelResponse", - "DeployedModel", - "DeployedModelRef", - "DiskSpec", - "EncryptionSpec", - "Endpoint", - "EndpointServiceClient", - "EnvVar", - "ExplainRequest", - "ExplainResponse", - "Explanation", - "ExplanationMetadata", - "ExplanationMetadataOverride", - "ExplanationParameters", - "ExplanationSpec", - "ExplanationSpecOverride", - "ExportDataConfig", - "ExportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportModelOperationMetadata", - "ExportModelRequest", - "ExportModelResponse", - "FeatureNoiseSigma", - "FilterSplit", - "FractionSplit", - "GcsDestination", - "GcsSource", - "GenericOperationMetadata", - "GetAnnotationSpecRequest", - "GetBatchPredictionJobRequest", - "GetCustomJobRequest", - "GetDataLabelingJobRequest", - "GetDatasetRequest", - "GetEndpointRequest", - "GetHyperparameterTuningJobRequest", - "GetModelEvaluationRequest", - "GetModelEvaluationSliceRequest", - "GetModelRequest", - "GetSpecialistPoolRequest", - "GetStudyRequest", - "GetTrainingPipelineRequest", - "GetTrialRequest", - "HyperparameterTuningJob", - "ImportDataConfig", - "ImportDataOperationMetadata", - "ImportDataRequest", - "ImportDataResponse", - "InputDataConfig", - "IntegratedGradientsAttribution", - "JobServiceClient", - "JobState", - "ListAnnotationsRequest", - "ListAnnotationsResponse", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "ListDataItemsRequest", - "ListDataItemsResponse", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "ListDatasetsRequest", - "ListDatasetsResponse", - "ListEndpointsRequest", - "ListEndpointsResponse", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "ListModelsRequest", - "ListModelsResponse", - "ListOptimalTrialsRequest", - "ListOptimalTrialsResponse", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "ListStudiesRequest", - "ListStudiesResponse", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "ListTrialsRequest", - "ListTrialsResponse", - "LookupStudyRequest", - "MachineSpec", - "ManualBatchTuningParameters", - "Measurement", - "MigratableResource", - "MigrateResourceRequest", - "MigrateResourceResponse", - "MigrationServiceClient", - "Model", - "ModelContainerSpec", - "ModelEvaluation", - "ModelEvaluationSlice", - "ModelExplanation", - "ModelServiceClient", - "PipelineServiceClient", - "PipelineState", - "Port", - "PredefinedSplit", - "PredictRequest", - "PredictResponse", - "PredictSchemata", - "PredictionServiceClient", - "PythonPackageSpec", - "ResourcesConsumed", - "SampleConfig", - "SampledShapleyAttribution", - "Scheduling", - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", - "SmoothGradConfig", - "SpecialistPool", - "SpecialistPoolServiceClient", - "StopTrialRequest", - "Study", - "StudySpec", - "SuggestTrialsMetadata", - "SuggestTrialsRequest", - "SuggestTrialsResponse", - "TimestampSplit", - "TrainingConfig", - "TrainingPipeline", - "Trial", - "UndeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UpdateDatasetRequest", - "UpdateEndpointRequest", - "UpdateModelRequest", - "UpdateSpecialistPoolOperationMetadata", - "UpdateSpecialistPoolRequest", - "UploadModelOperationMetadata", - "UploadModelRequest", - "UploadModelResponse", - "UserActionReference", - "WorkerPoolSpec", - "XraiAttribution", - "VizierServiceClient", + 'AcceleratorType', + 'ActiveLearningConfig', + 'AddContextArtifactsAndExecutionsRequest', + 'AddContextArtifactsAndExecutionsResponse', + 'AddContextChildrenRequest', + 'AddContextChildrenResponse', + 'AddExecutionEventsRequest', + 'AddExecutionEventsResponse', + 'AddTrialMeasurementRequest', + 'Annotation', + 'AnnotationSpec', + 'Artifact', + 'Attribution', + 'AutomaticResources', + 'AutoscalingMetricSpec', + 'BatchDedicatedResources', + 'BatchMigrateResourcesOperationMetadata', + 'BatchMigrateResourcesRequest', + 'BatchMigrateResourcesResponse', + 'BatchPredictionJob', + 'BigQueryDestination', + 'BigQuerySource', + 'CancelBatchPredictionJobRequest', + 'CancelCustomJobRequest', + 'CancelDataLabelingJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CancelTrainingPipelineRequest', + 'CheckTrialEarlyStoppingStateMetatdata', + 'CheckTrialEarlyStoppingStateRequest', + 'CheckTrialEarlyStoppingStateResponse', + 'CompleteTrialRequest', + 'CompletionStats', + 'ContainerRegistryDestination', + 'ContainerSpec', + 'Context', + 'CreateArtifactRequest', + 'CreateBatchPredictionJobRequest', + 'CreateContextRequest', + 'CreateCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'CreateDatasetOperationMetadata', + 'CreateDatasetRequest', + 'CreateEndpointOperationMetadata', + 'CreateEndpointRequest', + 'CreateExecutionRequest', + 'CreateHyperparameterTuningJobRequest', + 'CreateMetadataSchemaRequest', + 'CreateMetadataStoreOperationMetadata', + 'CreateMetadataStoreRequest', + 'CreateModelDeploymentMonitoringJobRequest', + 'CreateSpecialistPoolOperationMetadata', + 'CreateSpecialistPoolRequest', + 'CreateStudyRequest', + 'CreateTrainingPipelineRequest', + 'CreateTrialRequest', + 'CustomJob', + 'CustomJobSpec', + 'DataItem', + 'DataLabelingJob', + 'Dataset', + 'DatasetServiceClient', + 'DedicatedResources', + 'DeleteBatchPredictionJobRequest', + 'DeleteContextRequest', + 'DeleteCustomJobRequest', + 'DeleteDataLabelingJobRequest', + 'DeleteDatasetRequest', + 'DeleteEndpointRequest', + 'DeleteHyperparameterTuningJobRequest', + 'DeleteMetadataStoreOperationMetadata', + 'DeleteMetadataStoreRequest', + 'DeleteModelDeploymentMonitoringJobRequest', + 'DeleteModelRequest', + 'DeleteOperationMetadata', + 'DeleteSpecialistPoolRequest', + 'DeleteStudyRequest', + 'DeleteTrainingPipelineRequest', + 'DeleteTrialRequest', + 'DeployModelOperationMetadata', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployedModel', + 'DeployedModelRef', + 'DiskSpec', + 'EncryptionSpec', + 'Endpoint', + 'EndpointServiceClient', + 'EnvVar', + 'Event', + 'Execution', + 'ExplainRequest', + 'ExplainResponse', + 'Explanation', + 'ExplanationMetadata', + 'ExplanationMetadataOverride', + 'ExplanationParameters', + 'ExplanationSpec', + 'ExplanationSpecOverride', + 'ExportDataConfig', + 'ExportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportModelOperationMetadata', + 'ExportModelRequest', + 'ExportModelResponse', + 'FeatureNoiseSigma', + 'FeatureStatsAnomaly', + 'FilterSplit', + 'FractionSplit', + 'GcsDestination', + 'GcsSource', + 'GenericOperationMetadata', + 'GetAnnotationSpecRequest', + 'GetArtifactRequest', + 'GetBatchPredictionJobRequest', + 'GetContextRequest', + 'GetCustomJobRequest', + 'GetDataLabelingJobRequest', + 'GetDatasetRequest', + 'GetEndpointRequest', + 'GetExecutionRequest', + 'GetHyperparameterTuningJobRequest', + 'GetMetadataSchemaRequest', + 'GetMetadataStoreRequest', + 'GetModelDeploymentMonitoringJobRequest', + 'GetModelEvaluationRequest', + 'GetModelEvaluationSliceRequest', + 'GetModelRequest', + 'GetSpecialistPoolRequest', + 'GetStudyRequest', + 'GetTrainingPipelineRequest', + 'GetTrialRequest', + 'HyperparameterTuningJob', + 'ImportDataConfig', + 'ImportDataOperationMetadata', + 'ImportDataRequest', + 'ImportDataResponse', + 'InputDataConfig', + 'IntegratedGradientsAttribution', + 'JobServiceClient', + 'JobState', + 'LineageSubgraph', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', + 'ListArtifactsRequest', + 'ListArtifactsResponse', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'ListContextsRequest', + 'ListContextsResponse', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'ListExecutionsRequest', + 'ListExecutionsResponse', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'ListMetadataSchemasRequest', + 'ListMetadataSchemasResponse', + 'ListMetadataStoresRequest', + 'ListMetadataStoresResponse', + 'ListModelDeploymentMonitoringJobsRequest', + 'ListModelDeploymentMonitoringJobsResponse', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'ListModelsRequest', + 'ListModelsResponse', + 'ListOptimalTrialsRequest', + 'ListOptimalTrialsResponse', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'ListStudiesRequest', + 'ListStudiesResponse', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'ListTrialsRequest', + 'ListTrialsResponse', + 'LookupStudyRequest', + 'MachineSpec', + 'ManualBatchTuningParameters', + 'Measurement', + 'MetadataSchema', + 'MetadataStore', + 'MigratableResource', + 'MigrateResourceRequest', + 'MigrateResourceResponse', + 'MigrationServiceClient', + 'Model', + 'ModelContainerSpec', + 'ModelDeploymentMonitoringBigQueryTable', + 'ModelDeploymentMonitoringJob', + 'ModelDeploymentMonitoringObjectiveConfig', + 'ModelDeploymentMonitoringObjectiveType', + 'ModelDeploymentMonitoringScheduleConfig', + 'ModelEvaluation', + 'ModelEvaluationSlice', + 'ModelExplanation', + 'ModelMonitoringAlertConfig', + 'ModelMonitoringObjectiveConfig', + 'ModelMonitoringStatsAnomalies', + 'ModelServiceClient', + 'PauseModelDeploymentMonitoringJobRequest', + 'PipelineServiceClient', + 'PipelineState', + 'Port', + 'PredefinedSplit', + 'PredictRequest', + 'PredictResponse', + 'PredictSchemata', + 'PredictionServiceClient', + 'PythonPackageSpec', + 'QueryContextLineageSubgraphRequest', + 'QueryExecutionInputsAndOutputsRequest', + 'ResourcesConsumed', + 'ResumeModelDeploymentMonitoringJobRequest', + 'SampleConfig', + 'SampledShapleyAttribution', + 'SamplingStrategy', + 'Scheduling', + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', + 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', + 'SmoothGradConfig', + 'SpecialistPool', + 'SpecialistPoolServiceClient', + 'StopTrialRequest', + 'Study', + 'StudySpec', + 'SuggestTrialsMetadata', + 'SuggestTrialsRequest', + 'SuggestTrialsResponse', + 'ThresholdConfig', + 'TimestampSplit', + 'TrainingConfig', + 'TrainingPipeline', + 'Trial', + 'UndeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UpdateArtifactRequest', + 'UpdateContextRequest', + 'UpdateDatasetRequest', + 'UpdateEndpointRequest', + 'UpdateExecutionRequest', + 'UpdateModelDeploymentMonitoringJobOperationMetadata', + 'UpdateModelDeploymentMonitoringJobRequest', + 'UpdateModelRequest', + 'UpdateSpecialistPoolOperationMetadata', + 'UpdateSpecialistPoolRequest', + 'UploadModelOperationMetadata', + 'UploadModelRequest', + 'UploadModelResponse', + 'UserActionReference', + 'VizierServiceClient', + 'WorkerPoolSpec', + 'XraiAttribution', +'MetadataServiceClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py index 597f654cb9..9d1f004f6a 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import DatasetServiceAsyncClient __all__ = ( - "DatasetServiceClient", - "DatasetServiceAsyncClient", + 'DatasetServiceClient', + 'DatasetServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index d91df4b644..2eb9ce6f7a 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,42 +60,26 @@ class DatasetServiceAsyncClient: annotation_path = staticmethod(DatasetServiceClient.annotation_path) parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) - parse_annotation_spec_path = staticmethod( - DatasetServiceClient.parse_annotation_spec_path - ) + parse_annotation_spec_path = staticmethod(DatasetServiceClient.parse_annotation_spec_path) data_item_path = staticmethod(DatasetServiceClient.data_item_path) parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) dataset_path = staticmethod(DatasetServiceClient.dataset_path) parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) - common_billing_account_path = staticmethod( - DatasetServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - DatasetServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(DatasetServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(DatasetServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - DatasetServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(DatasetServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - DatasetServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - DatasetServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(DatasetServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(DatasetServiceClient.parse_common_organization_path) common_project_path = staticmethod(DatasetServiceClient.common_project_path) - parse_common_project_path = staticmethod( - DatasetServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(DatasetServiceClient.parse_common_project_path) common_location_path = staticmethod(DatasetServiceClient.common_location_path) - parse_common_location_path = staticmethod( - DatasetServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(DatasetServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -138,18 +122,14 @@ def transport(self) -> DatasetServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) - ) + get_transport_class = functools.partial(type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -188,18 +168,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_dataset( - self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_dataset(self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a Dataset. Args: @@ -240,10 +220,8 @@ async def create_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.CreateDatasetRequest(request) @@ -266,11 +244,18 @@ async def create_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -283,15 +268,14 @@ async def create_dataset( # Done; return the response. return response - async def get_dataset( - self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + async def get_dataset(self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -323,10 +307,8 @@ async def get_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.GetDatasetRequest(request) @@ -347,25 +329,31 @@ async def get_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def update_dataset( - self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + async def update_dataset(self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -410,10 +398,8 @@ async def update_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.UpdateDatasetRequest(request) @@ -436,26 +422,30 @@ async def update_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("dataset.name", request.dataset.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('dataset.name', request.dataset.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_datasets( - self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsAsyncPager: + async def list_datasets(self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: r"""Lists Datasets in a Location. Args: @@ -490,10 +480,8 @@ async def list_datasets( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ListDatasetsRequest(request) @@ -514,30 +502,39 @@ async def list_datasets( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDatasetsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_dataset( - self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_dataset(self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Dataset. Args: @@ -583,10 +580,8 @@ async def delete_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.DeleteDatasetRequest(request) @@ -607,11 +602,18 @@ async def delete_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -624,16 +626,15 @@ async def delete_dataset( # Done; return the response. return response - async def import_data( - self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def import_data(self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Imports data into a Dataset. Args: @@ -677,10 +678,8 @@ async def import_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ImportDataRequest(request) @@ -704,11 +703,18 @@ async def import_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -721,16 +727,15 @@ async def import_data( # Done; return the response. return response - async def export_data( - self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_data(self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports data from a Dataset. Args: @@ -773,10 +778,8 @@ async def export_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ExportDataRequest(request) @@ -799,11 +802,18 @@ async def export_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -816,15 +826,14 @@ async def export_data( # Done; return the response. return response - async def list_data_items( - self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsAsyncPager: + async def list_data_items(self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: r"""Lists DataItems in a Dataset. Args: @@ -860,10 +869,8 @@ async def list_data_items( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ListDataItemsRequest(request) @@ -884,30 +891,39 @@ async def list_data_items( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataItemsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def get_annotation_spec( - self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + async def get_annotation_spec(self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -940,10 +956,8 @@ async def get_annotation_spec( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.GetAnnotationSpecRequest(request) @@ -964,24 +978,30 @@ async def get_annotation_spec( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_annotations( - self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsAsyncPager: + async def list_annotations(self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1017,10 +1037,8 @@ async def list_annotations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = dataset_service.ListAnnotationsRequest(request) @@ -1041,30 +1059,47 @@ async def list_annotations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListAnnotationsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("DatasetServiceAsyncClient",) +__all__ = ( + 'DatasetServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 37aecfc5e5..9d139e6b64 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,14 +60,13 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry['grpc'] = DatasetServiceGrpcTransport + _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry["grpc"] = DatasetServiceGrpcTransport - _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -118,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -153,8 +152,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DatasetServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,149 +169,110 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path( - project: str, location: str, dataset: str, data_item: str, annotation: str, - ) -> str: + def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( - project=project, - location=location, - dataset=dataset, - data_item=data_item, - annotation=annotation, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str, str]: + def parse_annotation_path(path: str) -> Dict[str,str]: """Parse a annotation path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path( - project: str, location: str, dataset: str, annotation_spec: str, - ) -> str: + def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( - project=project, - location=location, - dataset=dataset, - annotation_spec=annotation_spec, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str, str]: + def parse_annotation_spec_path(path: str) -> Dict[str,str]: """Parse a annotation_spec path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def data_item_path( - project: str, location: str, dataset: str, data_item: str, - ) -> str: + def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( - project=project, location=location, dataset=dataset, data_item=data_item, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str, str]: + def parse_data_item_path(path: str) -> Dict[str,str]: """Parse a data_item path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -355,9 +316,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -367,9 +326,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -381,9 +338,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -395,10 +350,8 @@ def __init__( if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -417,16 +370,15 @@ def __init__( client_info=client_info, ) - def create_dataset( - self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_dataset(self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a Dataset. Args: @@ -467,10 +419,8 @@ def create_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -494,11 +444,18 @@ def create_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -511,15 +468,14 @@ def create_dataset( # Done; return the response. return response - def get_dataset( - self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset(self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -551,10 +507,8 @@ def get_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -576,25 +530,31 @@ def get_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def update_dataset( - self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset(self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -639,10 +599,8 @@ def update_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -666,26 +624,30 @@ def update_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("dataset.name", request.dataset.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('dataset.name', request.dataset.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_datasets( - self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets(self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -720,10 +682,8 @@ def list_datasets( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -745,30 +705,39 @@ def list_datasets( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_dataset( - self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_dataset(self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Dataset. Args: @@ -814,10 +783,8 @@ def delete_dataset( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -839,11 +806,18 @@ def delete_dataset( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -856,16 +830,15 @@ def delete_dataset( # Done; return the response. return response - def import_data( - self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def import_data(self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Imports data into a Dataset. Args: @@ -909,10 +882,8 @@ def import_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -936,11 +907,18 @@ def import_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -953,16 +931,15 @@ def import_data( # Done; return the response. return response - def export_data( - self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_data(self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports data from a Dataset. Args: @@ -1005,10 +982,8 @@ def export_data( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -1032,11 +1007,18 @@ def export_data( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1049,15 +1031,14 @@ def export_data( # Done; return the response. return response - def list_data_items( - self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items(self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -1093,10 +1074,8 @@ def list_data_items( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1118,30 +1097,39 @@ def list_data_items( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec( - self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec(self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1174,10 +1162,8 @@ def get_annotation_spec( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1199,24 +1185,30 @@ def get_annotation_spec( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_annotations( - self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations(self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1252,10 +1244,8 @@ def list_annotations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1277,30 +1267,47 @@ def list_annotations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("DatasetServiceClient",) +__all__ = ( + 'DatasetServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index 63560b32ba..aa9114bc5f 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item @@ -49,15 +40,12 @@ class ListDatasetsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., dataset_service.ListDatasetsResponse], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -91,7 +79,7 @@ def __iter__(self) -> Iterable[dataset.Dataset]: yield from page.datasets def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDatasetsAsyncPager: @@ -111,15 +99,12 @@ class ListDatasetsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -157,7 +142,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataItemsPager: @@ -177,15 +162,12 @@ class ListDataItemsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., dataset_service.ListDataItemsResponse], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -219,7 +201,7 @@ def __iter__(self) -> Iterable[data_item.DataItem]: yield from page.data_items def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataItemsAsyncPager: @@ -239,15 +221,12 @@ class ListDataItemsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -285,7 +264,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListAnnotationsPager: @@ -305,15 +284,12 @@ class ListAnnotationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., dataset_service.ListAnnotationsResponse], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -347,7 +323,7 @@ def __iter__(self) -> Iterable[annotation.Annotation]: yield from page.annotations def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListAnnotationsAsyncPager: @@ -367,15 +343,12 @@ class ListAnnotationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -413,4 +386,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index a4461d2ced..5f02a0f0d9 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] -_transport_registry["grpc"] = DatasetServiceGrpcTransport -_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = DatasetServiceGrpcTransport +_transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport __all__ = ( - "DatasetServiceTransport", - "DatasetServiceGrpcTransport", - "DatasetServiceGrpcAsyncIOTransport", + 'DatasetServiceTransport', + 'DatasetServiceGrpcTransport', + 'DatasetServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 56f567959a..74909b2980 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -74,73 +74,92 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, default_timeout=5.0, client_info=client_info, + self.create_dataset, + default_timeout=5.0, + client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, default_timeout=5.0, client_info=client_info, + self.get_dataset, + default_timeout=5.0, + client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, default_timeout=5.0, client_info=client_info, + self.update_dataset, + default_timeout=5.0, + client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, default_timeout=5.0, client_info=client_info, + self.list_datasets, + default_timeout=5.0, + client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, default_timeout=5.0, client_info=client_info, + self.delete_dataset, + default_timeout=5.0, + client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( - self.import_data, default_timeout=5.0, client_info=client_info, + self.import_data, + default_timeout=5.0, + client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( - self.export_data, default_timeout=5.0, client_info=client_info, + self.export_data, + default_timeout=5.0, + client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, default_timeout=5.0, client_info=client_info, + self.list_data_items, + default_timeout=5.0, + client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, default_timeout=5.0, client_info=client_info, + self.get_annotation_spec, + default_timeout=5.0, + client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, default_timeout=5.0, client_info=client_info, + self.list_annotations, + default_timeout=5.0, + client_info=client_info, ), + } @property @@ -149,106 +168,96 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_dataset( - self, - ) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def create_dataset(self) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_dataset( - self, - ) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], - ]: + def get_dataset(self) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[ + dataset.Dataset, + typing.Awaitable[dataset.Dataset] + ]]: raise NotImplementedError() @property - def update_dataset( - self, - ) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], - ]: + def update_dataset(self) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[ + gca_dataset.Dataset, + typing.Awaitable[gca_dataset.Dataset] + ]]: raise NotImplementedError() @property - def list_datasets( - self, - ) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse], - ], - ]: + def list_datasets(self) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse] + ]]: raise NotImplementedError() @property - def delete_dataset( - self, - ) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_dataset(self) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def import_data( - self, - ) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def import_data(self) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def export_data( - self, - ) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def export_data(self) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def list_data_items( - self, - ) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse], - ], - ]: + def list_data_items(self) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse] + ]]: raise NotImplementedError() @property - def get_annotation_spec( - self, - ) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec], - ], - ]: + def get_annotation_spec(self) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec] + ]]: raise NotImplementedError() @property - def list_annotations( - self, - ) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse], - ], - ]: + def list_annotations(self) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse] + ]]: raise NotImplementedError() -__all__ = ("DatasetServiceTransport",) +__all__ = ( + 'DatasetServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 4dae75d109..39f0405cfa 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -46,24 +46,21 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -109,7 +106,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -117,70 +117,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -188,32 +168,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -243,12 +211,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -260,15 +229,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_dataset( - self, - ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: + def create_dataset(self) -> Callable[ + [dataset_service.CreateDatasetRequest], + operations.Operation]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -283,18 +254,18 @@ def create_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_dataset" not in self._stubs: - self._stubs["create_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", + if 'create_dataset' not in self._stubs: + self._stubs['create_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_dataset"] + return self._stubs['create_dataset'] @property - def get_dataset( - self, - ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: + def get_dataset(self) -> Callable[ + [dataset_service.GetDatasetRequest], + dataset.Dataset]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -309,18 +280,18 @@ def get_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_dataset" not in self._stubs: - self._stubs["get_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", + if 'get_dataset' not in self._stubs: + self._stubs['get_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs["get_dataset"] + return self._stubs['get_dataset'] @property - def update_dataset( - self, - ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: + def update_dataset(self) -> Callable[ + [dataset_service.UpdateDatasetRequest], + gca_dataset.Dataset]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -335,20 +306,18 @@ def update_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_dataset" not in self._stubs: - self._stubs["update_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", + if 'update_dataset' not in self._stubs: + self._stubs['update_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs["update_dataset"] + return self._stubs['update_dataset'] @property - def list_datasets( - self, - ) -> Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse - ]: + def list_datasets(self) -> Callable[ + [dataset_service.ListDatasetsRequest], + dataset_service.ListDatasetsResponse]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -363,18 +332,18 @@ def list_datasets( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_datasets" not in self._stubs: - self._stubs["list_datasets"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", + if 'list_datasets' not in self._stubs: + self._stubs['list_datasets'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs["list_datasets"] + return self._stubs['list_datasets'] @property - def delete_dataset( - self, - ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: + def delete_dataset(self) -> Callable[ + [dataset_service.DeleteDatasetRequest], + operations.Operation]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -389,18 +358,18 @@ def delete_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_dataset" not in self._stubs: - self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", + if 'delete_dataset' not in self._stubs: + self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_dataset"] + return self._stubs['delete_dataset'] @property - def import_data( - self, - ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: + def import_data(self) -> Callable[ + [dataset_service.ImportDataRequest], + operations.Operation]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -415,18 +384,18 @@ def import_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "import_data" not in self._stubs: - self._stubs["import_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", + if 'import_data' not in self._stubs: + self._stubs['import_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["import_data"] + return self._stubs['import_data'] @property - def export_data( - self, - ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: + def export_data(self) -> Callable[ + [dataset_service.ExportDataRequest], + operations.Operation]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -441,20 +410,18 @@ def export_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_data" not in self._stubs: - self._stubs["export_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", + if 'export_data' not in self._stubs: + self._stubs['export_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_data"] + return self._stubs['export_data'] @property - def list_data_items( - self, - ) -> Callable[ - [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse - ]: + def list_data_items(self) -> Callable[ + [dataset_service.ListDataItemsRequest], + dataset_service.ListDataItemsResponse]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -469,20 +436,18 @@ def list_data_items( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_items" not in self._stubs: - self._stubs["list_data_items"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", + if 'list_data_items' not in self._stubs: + self._stubs['list_data_items'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs["list_data_items"] + return self._stubs['list_data_items'] @property - def get_annotation_spec( - self, - ) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec - ]: + def get_annotation_spec(self) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + annotation_spec.AnnotationSpec]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -497,21 +462,18 @@ def get_annotation_spec( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_annotation_spec" not in self._stubs: - self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", + if 'get_annotation_spec' not in self._stubs: + self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs["get_annotation_spec"] + return self._stubs['get_annotation_spec'] @property - def list_annotations( - self, - ) -> Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, - ]: + def list_annotations(self) -> Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -526,13 +488,15 @@ def list_annotations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_annotations" not in self._stubs: - self._stubs["list_annotations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", + if 'list_annotations' not in self._stubs: + self._stubs['list_annotations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs["list_annotations"] + return self._stubs['list_annotations'] -__all__ = ("DatasetServiceGrpcTransport",) +__all__ = ( + 'DatasetServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index 0c38b2ec38..6ed4e0785b 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import annotation_spec @@ -53,18 +53,16 @@ class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -90,24 +88,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -142,10 +138,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -154,7 +150,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -162,70 +161,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -233,18 +212,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -273,11 +242,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_dataset( - self, - ) -> Callable[ - [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] - ]: + def create_dataset(self) -> Callable[ + [dataset_service.CreateDatasetRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -292,18 +259,18 @@ def create_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_dataset" not in self._stubs: - self._stubs["create_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", + if 'create_dataset' not in self._stubs: + self._stubs['create_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_dataset"] + return self._stubs['create_dataset'] @property - def get_dataset( - self, - ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: + def get_dataset(self) -> Callable[ + [dataset_service.GetDatasetRequest], + Awaitable[dataset.Dataset]]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -318,20 +285,18 @@ def get_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_dataset" not in self._stubs: - self._stubs["get_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", + if 'get_dataset' not in self._stubs: + self._stubs['get_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs["get_dataset"] + return self._stubs['get_dataset'] @property - def update_dataset( - self, - ) -> Callable[ - [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] - ]: + def update_dataset(self) -> Callable[ + [dataset_service.UpdateDatasetRequest], + Awaitable[gca_dataset.Dataset]]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -346,21 +311,18 @@ def update_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_dataset" not in self._stubs: - self._stubs["update_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", + if 'update_dataset' not in self._stubs: + self._stubs['update_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs["update_dataset"] + return self._stubs['update_dataset'] @property - def list_datasets( - self, - ) -> Callable[ - [dataset_service.ListDatasetsRequest], - Awaitable[dataset_service.ListDatasetsResponse], - ]: + def list_datasets(self) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse]]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -375,20 +337,18 @@ def list_datasets( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_datasets" not in self._stubs: - self._stubs["list_datasets"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", + if 'list_datasets' not in self._stubs: + self._stubs['list_datasets'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs["list_datasets"] + return self._stubs['list_datasets'] @property - def delete_dataset( - self, - ) -> Callable[ - [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] - ]: + def delete_dataset(self) -> Callable[ + [dataset_service.DeleteDatasetRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -403,18 +363,18 @@ def delete_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_dataset" not in self._stubs: - self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", + if 'delete_dataset' not in self._stubs: + self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_dataset"] + return self._stubs['delete_dataset'] @property - def import_data( - self, - ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: + def import_data(self) -> Callable[ + [dataset_service.ImportDataRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -429,18 +389,18 @@ def import_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "import_data" not in self._stubs: - self._stubs["import_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", + if 'import_data' not in self._stubs: + self._stubs['import_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["import_data"] + return self._stubs['import_data'] @property - def export_data( - self, - ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: + def export_data(self) -> Callable[ + [dataset_service.ExportDataRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -455,21 +415,18 @@ def export_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_data" not in self._stubs: - self._stubs["export_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", + if 'export_data' not in self._stubs: + self._stubs['export_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_data"] + return self._stubs['export_data'] @property - def list_data_items( - self, - ) -> Callable[ - [dataset_service.ListDataItemsRequest], - Awaitable[dataset_service.ListDataItemsResponse], - ]: + def list_data_items(self) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse]]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -484,21 +441,18 @@ def list_data_items( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_items" not in self._stubs: - self._stubs["list_data_items"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", + if 'list_data_items' not in self._stubs: + self._stubs['list_data_items'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs["list_data_items"] + return self._stubs['list_data_items'] @property - def get_annotation_spec( - self, - ) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - Awaitable[annotation_spec.AnnotationSpec], - ]: + def get_annotation_spec(self) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + Awaitable[annotation_spec.AnnotationSpec]]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -513,21 +467,18 @@ def get_annotation_spec( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_annotation_spec" not in self._stubs: - self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", + if 'get_annotation_spec' not in self._stubs: + self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs["get_annotation_spec"] + return self._stubs['get_annotation_spec'] @property - def list_annotations( - self, - ) -> Callable[ - [dataset_service.ListAnnotationsRequest], - Awaitable[dataset_service.ListAnnotationsResponse], - ]: + def list_annotations(self) -> Callable[ + [dataset_service.ListAnnotationsRequest], + Awaitable[dataset_service.ListAnnotationsResponse]]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -542,13 +493,15 @@ def list_annotations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_annotations" not in self._stubs: - self._stubs["list_annotations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", + if 'list_annotations' not in self._stubs: + self._stubs['list_annotations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs["list_annotations"] + return self._stubs['list_annotations'] -__all__ = ("DatasetServiceGrpcAsyncIOTransport",) +__all__ = ( + 'DatasetServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py index 035a5b2388..e4f3dcfbcf 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import EndpointServiceAsyncClient __all__ = ( - "EndpointServiceClient", - "EndpointServiceAsyncClient", + 'EndpointServiceClient', + 'EndpointServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 05aa538225..daadc92c9e 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -58,34 +58,20 @@ class EndpointServiceAsyncClient: model_path = staticmethod(EndpointServiceClient.model_path) parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) - common_billing_account_path = staticmethod( - EndpointServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - EndpointServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(EndpointServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(EndpointServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - EndpointServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(EndpointServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - EndpointServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - EndpointServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(EndpointServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(EndpointServiceClient.parse_common_organization_path) common_project_path = staticmethod(EndpointServiceClient.common_project_path) - parse_common_project_path = staticmethod( - EndpointServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(EndpointServiceClient.parse_common_project_path) common_location_path = staticmethod(EndpointServiceClient.common_location_path) - parse_common_location_path = staticmethod( - EndpointServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(EndpointServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -128,18 +114,14 @@ def transport(self) -> EndpointServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) - ) + get_transport_class = functools.partial(type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -178,18 +160,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_endpoint( - self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_endpoint(self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an Endpoint. Args: @@ -229,10 +211,8 @@ async def create_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.CreateEndpointRequest(request) @@ -255,11 +235,18 @@ async def create_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -272,15 +259,14 @@ async def create_endpoint( # Done; return the response. return response - async def get_endpoint( - self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + async def get_endpoint(self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -313,10 +299,8 @@ async def get_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.GetEndpointRequest(request) @@ -337,24 +321,30 @@ async def get_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_endpoints( - self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsAsyncPager: + async def list_endpoints(self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: r"""Lists Endpoints in a Location. Args: @@ -390,10 +380,8 @@ async def list_endpoints( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.ListEndpointsRequest(request) @@ -414,31 +402,40 @@ async def list_endpoints( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListEndpointsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def update_endpoint( - self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + async def update_endpoint(self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -478,10 +475,8 @@ async def update_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.UpdateEndpointRequest(request) @@ -504,26 +499,30 @@ async def update_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("endpoint.name", request.endpoint.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint.name', request.endpoint.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def delete_endpoint( - self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_endpoint(self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an Endpoint. Args: @@ -569,10 +568,8 @@ async def delete_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.DeleteEndpointRequest(request) @@ -593,11 +590,18 @@ async def delete_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -610,19 +614,16 @@ async def delete_endpoint( # Done; return the response. return response - async def deploy_model( - self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[ - endpoint_service.DeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def deploy_model(self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -691,10 +692,8 @@ async def deploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.DeployModelRequest(request) @@ -720,11 +719,18 @@ async def deploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -737,19 +743,16 @@ async def deploy_model( # Done; return the response. return response - async def undeploy_model( - self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[ - endpoint_service.UndeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def undeploy_model(self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -809,10 +812,8 @@ async def undeploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = endpoint_service.UndeployModelRequest(request) @@ -838,11 +839,18 @@ async def undeploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -856,14 +864,21 @@ async def undeploy_model( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("EndpointServiceAsyncClient",) +__all__ = ( + 'EndpointServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 1fdf1e506e..78822a9489 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -56,14 +56,13 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry['grpc'] = EndpointServiceGrpcTransport + _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry["grpc"] = EndpointServiceGrpcTransport - _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -114,7 +113,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -149,8 +148,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: EndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,104 +165,88 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -306,9 +290,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -318,9 +300,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -332,9 +312,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -346,10 +324,8 @@ def __init__( if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -368,16 +344,15 @@ def __init__( client_info=client_info, ) - def create_endpoint( - self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_endpoint(self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates an Endpoint. Args: @@ -417,10 +392,8 @@ def create_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -444,11 +417,18 @@ def create_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -461,15 +441,14 @@ def create_endpoint( # Done; return the response. return response - def get_endpoint( - self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint(self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -502,10 +481,8 @@ def get_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -527,24 +504,30 @@ def get_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_endpoints( - self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints(self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -580,10 +563,8 @@ def list_endpoints( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -605,31 +586,40 @@ def list_endpoints( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def update_endpoint( - self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint(self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -669,10 +659,8 @@ def update_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -696,26 +684,30 @@ def update_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("endpoint.name", request.endpoint.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint.name', request.endpoint.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_endpoint( - self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_endpoint(self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes an Endpoint. Args: @@ -761,10 +753,8 @@ def delete_endpoint( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -786,11 +776,18 @@ def delete_endpoint( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -803,19 +800,16 @@ def delete_endpoint( # Done; return the response. return response - def deploy_model( - self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[ - endpoint_service.DeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def deploy_model(self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -884,10 +878,8 @@ def deploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -913,11 +905,18 @@ def deploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -930,19 +929,16 @@ def deploy_model( # Done; return the response. return response - def undeploy_model( - self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[ - endpoint_service.UndeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def undeploy_model(self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -1002,10 +998,8 @@ def undeploy_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -1031,11 +1025,18 @@ def undeploy_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1049,14 +1050,21 @@ def undeploy_model( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("EndpointServiceClient",) +__all__ = ( + 'EndpointServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index db3172bcef..4261cca3fb 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -47,15 +38,12 @@ class ListEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., endpoint_service.ListEndpointsResponse], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: yield from page.endpoints def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListEndpointsAsyncPager: @@ -109,15 +97,12 @@ class ListEndpointsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -155,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index 3d0695461d..eb2ef767fe 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] -_transport_registry["grpc"] = EndpointServiceGrpcTransport -_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = EndpointServiceGrpcTransport +_transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport __all__ = ( - "EndpointServiceTransport", - "EndpointServiceGrpcTransport", - "EndpointServiceGrpcAsyncIOTransport", + 'EndpointServiceTransport', + 'EndpointServiceGrpcTransport', + 'EndpointServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index e55589de8f..85c53f94e3 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -73,64 +73,77 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, default_timeout=5.0, client_info=client_info, + self.create_endpoint, + default_timeout=5.0, + client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, default_timeout=5.0, client_info=client_info, + self.get_endpoint, + default_timeout=5.0, + client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, default_timeout=5.0, client_info=client_info, + self.list_endpoints, + default_timeout=5.0, + client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, default_timeout=5.0, client_info=client_info, + self.update_endpoint, + default_timeout=5.0, + client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, default_timeout=5.0, client_info=client_info, + self.delete_endpoint, + default_timeout=5.0, + client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, default_timeout=5.0, client_info=client_info, + self.deploy_model, + default_timeout=5.0, + client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, default_timeout=5.0, client_info=client_info, + self.undeploy_model, + default_timeout=5.0, + client_info=client_info, ), + } @property @@ -139,70 +152,69 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def create_endpoint(self) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], - ]: + def get_endpoint(self) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[ + endpoint.Endpoint, + typing.Awaitable[endpoint.Endpoint] + ]]: raise NotImplementedError() @property - def list_endpoints( - self, - ) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse], - ], - ]: + def list_endpoints(self) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse] + ]]: raise NotImplementedError() @property - def update_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], - ]: + def update_endpoint(self) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[ + gca_endpoint.Endpoint, + typing.Awaitable[gca_endpoint.Endpoint] + ]]: raise NotImplementedError() @property - def delete_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_endpoint(self) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def deploy_model( - self, - ) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def deploy_model(self) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def undeploy_model( - self, - ) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def undeploy_model(self) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() -__all__ = ("EndpointServiceTransport",) +__all__ = ( + 'EndpointServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 455ed12cf4..555432fec0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -45,24 +45,21 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -108,7 +105,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -116,70 +116,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -187,32 +167,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -242,12 +210,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -259,15 +228,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_endpoint( - self, - ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]: + def create_endpoint(self) -> Callable[ + [endpoint_service.CreateEndpointRequest], + operations.Operation]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -282,18 +253,18 @@ def create_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_endpoint" not in self._stubs: - self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", + if 'create_endpoint' not in self._stubs: + self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_endpoint"] + return self._stubs['create_endpoint'] @property - def get_endpoint( - self, - ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: + def get_endpoint(self) -> Callable[ + [endpoint_service.GetEndpointRequest], + endpoint.Endpoint]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -308,20 +279,18 @@ def get_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_endpoint" not in self._stubs: - self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", + if 'get_endpoint' not in self._stubs: + self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs["get_endpoint"] + return self._stubs['get_endpoint'] @property - def list_endpoints( - self, - ) -> Callable[ - [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse - ]: + def list_endpoints(self) -> Callable[ + [endpoint_service.ListEndpointsRequest], + endpoint_service.ListEndpointsResponse]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -336,18 +305,18 @@ def list_endpoints( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_endpoints" not in self._stubs: - self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", + if 'list_endpoints' not in self._stubs: + self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs["list_endpoints"] + return self._stubs['list_endpoints'] @property - def update_endpoint( - self, - ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: + def update_endpoint(self) -> Callable[ + [endpoint_service.UpdateEndpointRequest], + gca_endpoint.Endpoint]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -362,18 +331,18 @@ def update_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_endpoint" not in self._stubs: - self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", + if 'update_endpoint' not in self._stubs: + self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs["update_endpoint"] + return self._stubs['update_endpoint'] @property - def delete_endpoint( - self, - ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: + def delete_endpoint(self) -> Callable[ + [endpoint_service.DeleteEndpointRequest], + operations.Operation]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -388,18 +357,18 @@ def delete_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_endpoint" not in self._stubs: - self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", + if 'delete_endpoint' not in self._stubs: + self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_endpoint"] + return self._stubs['delete_endpoint'] @property - def deploy_model( - self, - ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: + def deploy_model(self) -> Callable[ + [endpoint_service.DeployModelRequest], + operations.Operation]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -415,18 +384,18 @@ def deploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "deploy_model" not in self._stubs: - self._stubs["deploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", + if 'deploy_model' not in self._stubs: + self._stubs['deploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel', request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["deploy_model"] + return self._stubs['deploy_model'] @property - def undeploy_model( - self, - ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: + def undeploy_model(self) -> Callable[ + [endpoint_service.UndeployModelRequest], + operations.Operation]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -443,13 +412,15 @@ def undeploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "undeploy_model" not in self._stubs: - self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", + if 'undeploy_model' not in self._stubs: + self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel', request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["undeploy_model"] + return self._stubs['undeploy_model'] -__all__ = ("EndpointServiceGrpcTransport",) +__all__ = ( + 'EndpointServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index a00971a72e..1c5fe7e1f4 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import endpoint @@ -52,18 +52,16 @@ class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -89,24 +87,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -141,10 +137,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -153,7 +149,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -161,70 +160,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -232,18 +211,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -272,11 +241,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_endpoint( - self, - ) -> Callable[ - [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] - ]: + def create_endpoint(self) -> Callable[ + [endpoint_service.CreateEndpointRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -291,18 +258,18 @@ def create_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_endpoint" not in self._stubs: - self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", + if 'create_endpoint' not in self._stubs: + self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_endpoint"] + return self._stubs['create_endpoint'] @property - def get_endpoint( - self, - ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: + def get_endpoint(self) -> Callable[ + [endpoint_service.GetEndpointRequest], + Awaitable[endpoint.Endpoint]]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -317,21 +284,18 @@ def get_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_endpoint" not in self._stubs: - self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", + if 'get_endpoint' not in self._stubs: + self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs["get_endpoint"] + return self._stubs['get_endpoint'] @property - def list_endpoints( - self, - ) -> Callable[ - [endpoint_service.ListEndpointsRequest], - Awaitable[endpoint_service.ListEndpointsResponse], - ]: + def list_endpoints(self) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse]]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -346,20 +310,18 @@ def list_endpoints( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_endpoints" not in self._stubs: - self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", + if 'list_endpoints' not in self._stubs: + self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs["list_endpoints"] + return self._stubs['list_endpoints'] @property - def update_endpoint( - self, - ) -> Callable[ - [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] - ]: + def update_endpoint(self) -> Callable[ + [endpoint_service.UpdateEndpointRequest], + Awaitable[gca_endpoint.Endpoint]]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -374,20 +336,18 @@ def update_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_endpoint" not in self._stubs: - self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", + if 'update_endpoint' not in self._stubs: + self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs["update_endpoint"] + return self._stubs['update_endpoint'] @property - def delete_endpoint( - self, - ) -> Callable[ - [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] - ]: + def delete_endpoint(self) -> Callable[ + [endpoint_service.DeleteEndpointRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -402,20 +362,18 @@ def delete_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_endpoint" not in self._stubs: - self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", + if 'delete_endpoint' not in self._stubs: + self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_endpoint"] + return self._stubs['delete_endpoint'] @property - def deploy_model( - self, - ) -> Callable[ - [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] - ]: + def deploy_model(self) -> Callable[ + [endpoint_service.DeployModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -431,20 +389,18 @@ def deploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "deploy_model" not in self._stubs: - self._stubs["deploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", + if 'deploy_model' not in self._stubs: + self._stubs['deploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel', request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["deploy_model"] + return self._stubs['deploy_model'] @property - def undeploy_model( - self, - ) -> Callable[ - [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] - ]: + def undeploy_model(self) -> Callable[ + [endpoint_service.UndeployModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -461,13 +417,15 @@ def undeploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "undeploy_model" not in self._stubs: - self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", + if 'undeploy_model' not in self._stubs: + self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel', request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["undeploy_model"] + return self._stubs['undeploy_model'] -__all__ = ("EndpointServiceGrpcAsyncIOTransport",) +__all__ = ( + 'EndpointServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py index 5f157047f5..037407b714 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import JobServiceAsyncClient __all__ = ( - "JobServiceClient", - "JobServiceAsyncClient", + 'JobServiceClient', + 'JobServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 366cbf0f52..8b0e8331bb 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -21,40 +21,40 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study +from google.protobuf import duration_pb2 as duration # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore @@ -74,50 +74,38 @@ class JobServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) - parse_batch_prediction_job_path = staticmethod( - JobServiceClient.parse_batch_prediction_job_path - ) + parse_batch_prediction_job_path = staticmethod(JobServiceClient.parse_batch_prediction_job_path) custom_job_path = staticmethod(JobServiceClient.custom_job_path) parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) - parse_data_labeling_job_path = staticmethod( - JobServiceClient.parse_data_labeling_job_path - ) + parse_data_labeling_job_path = staticmethod(JobServiceClient.parse_data_labeling_job_path) dataset_path = staticmethod(JobServiceClient.dataset_path) parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) - hyperparameter_tuning_job_path = staticmethod( - JobServiceClient.hyperparameter_tuning_job_path - ) - parse_hyperparameter_tuning_job_path = staticmethod( - JobServiceClient.parse_hyperparameter_tuning_job_path - ) + endpoint_path = staticmethod(JobServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(JobServiceClient.parse_endpoint_path) + hyperparameter_tuning_job_path = staticmethod(JobServiceClient.hyperparameter_tuning_job_path) + parse_hyperparameter_tuning_job_path = staticmethod(JobServiceClient.parse_hyperparameter_tuning_job_path) model_path = staticmethod(JobServiceClient.model_path) parse_model_path = staticmethod(JobServiceClient.parse_model_path) + model_deployment_monitoring_job_path = staticmethod(JobServiceClient.model_deployment_monitoring_job_path) + parse_model_deployment_monitoring_job_path = staticmethod(JobServiceClient.parse_model_deployment_monitoring_job_path) trial_path = staticmethod(JobServiceClient.trial_path) parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) - common_billing_account_path = staticmethod( - JobServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - JobServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(JobServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(JobServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(JobServiceClient.common_folder_path) parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) common_organization_path = staticmethod(JobServiceClient.common_organization_path) - parse_common_organization_path = staticmethod( - JobServiceClient.parse_common_organization_path - ) + parse_common_organization_path = staticmethod(JobServiceClient.parse_common_organization_path) common_project_path = staticmethod(JobServiceClient.common_project_path) parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) common_location_path = staticmethod(JobServiceClient.common_location_path) - parse_common_location_path = staticmethod( - JobServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(JobServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -160,18 +148,14 @@ def transport(self) -> JobServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(JobServiceClient).get_transport_class, type(JobServiceClient) - ) + get_transport_class = functools.partial(type(JobServiceClient).get_transport_class, type(JobServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, JobServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -210,18 +194,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_custom_job( - self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + async def create_custom_job(self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -266,10 +250,8 @@ async def create_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateCustomJobRequest(request) @@ -292,24 +274,30 @@ async def create_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_custom_job( - self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + async def get_custom_job(self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -347,10 +335,8 @@ async def get_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetCustomJobRequest(request) @@ -371,24 +357,30 @@ async def get_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_custom_jobs( - self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsAsyncPager: + async def list_custom_jobs(self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: r"""Lists CustomJobs in a Location. Args: @@ -424,10 +416,8 @@ async def list_custom_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListCustomJobsRequest(request) @@ -448,30 +438,39 @@ async def list_custom_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListCustomJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_custom_job( - self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_custom_job(self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a CustomJob. Args: @@ -517,10 +516,8 @@ async def delete_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteCustomJobRequest(request) @@ -541,11 +538,18 @@ async def delete_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -558,15 +562,14 @@ async def delete_custom_job( # Done; return the response. return response - async def cancel_custom_job( - self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_custom_job(self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -604,10 +607,8 @@ async def cancel_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelCustomJobRequest(request) @@ -628,24 +629,28 @@ async def cancel_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_data_labeling_job( - self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_data_labeling_job(self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -685,10 +690,8 @@ async def create_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateDataLabelingJobRequest(request) @@ -711,24 +714,30 @@ async def create_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_data_labeling_job( - self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + async def get_data_labeling_job(self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -761,10 +770,8 @@ async def get_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetDataLabelingJobRequest(request) @@ -785,24 +792,30 @@ async def get_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_data_labeling_jobs( - self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsAsyncPager: + async def list_data_labeling_jobs(self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -837,10 +850,8 @@ async def list_data_labeling_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListDataLabelingJobsRequest(request) @@ -861,30 +872,39 @@ async def list_data_labeling_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataLabelingJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_data_labeling_job( - self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_data_labeling_job(self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a DataLabelingJob. Args: @@ -930,10 +950,8 @@ async def delete_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteDataLabelingJobRequest(request) @@ -954,11 +972,18 @@ async def delete_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -971,15 +996,14 @@ async def delete_data_labeling_job( # Done; return the response. return response - async def cancel_data_labeling_job( - self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_data_labeling_job(self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1006,10 +1030,8 @@ async def cancel_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelDataLabelingJobRequest(request) @@ -1030,24 +1052,28 @@ async def cancel_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_hyperparameter_tuning_job( - self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_hyperparameter_tuning_job(self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1089,10 +1115,8 @@ async def create_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateHyperparameterTuningJobRequest(request) @@ -1115,24 +1139,30 @@ async def create_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_hyperparameter_tuning_job( - self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + async def get_hyperparameter_tuning_job(self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1167,10 +1197,8 @@ async def get_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetHyperparameterTuningJobRequest(request) @@ -1191,24 +1219,30 @@ async def get_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_hyperparameter_tuning_jobs( - self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + async def list_hyperparameter_tuning_jobs(self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1244,10 +1278,8 @@ async def list_hyperparameter_tuning_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListHyperparameterTuningJobsRequest(request) @@ -1268,30 +1300,39 @@ async def list_hyperparameter_tuning_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListHyperparameterTuningJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_hyperparameter_tuning_job( - self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_hyperparameter_tuning_job(self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1337,10 +1378,8 @@ async def delete_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteHyperparameterTuningJobRequest(request) @@ -1361,11 +1400,18 @@ async def delete_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1378,15 +1424,14 @@ async def delete_hyperparameter_tuning_job( # Done; return the response. return response - async def cancel_hyperparameter_tuning_job( - self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_hyperparameter_tuning_job(self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1426,10 +1471,8 @@ async def cancel_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelHyperparameterTuningJobRequest(request) @@ -1450,24 +1493,28 @@ async def cancel_hyperparameter_tuning_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_batch_prediction_job( - self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_batch_prediction_job(self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1512,10 +1559,8 @@ async def create_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CreateBatchPredictionJobRequest(request) @@ -1538,24 +1583,30 @@ async def create_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_batch_prediction_job( - self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + async def get_batch_prediction_job(self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1592,10 +1643,8 @@ async def get_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.GetBatchPredictionJobRequest(request) @@ -1616,24 +1665,30 @@ async def get_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_batch_prediction_jobs( - self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsAsyncPager: + async def list_batch_prediction_jobs(self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsAsyncPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1669,10 +1724,8 @@ async def list_batch_prediction_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.ListBatchPredictionJobsRequest(request) @@ -1693,30 +1746,39 @@ async def list_batch_prediction_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListBatchPredictionJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_batch_prediction_job( - self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_batch_prediction_job(self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -1763,10 +1825,8 @@ async def delete_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.DeleteBatchPredictionJobRequest(request) @@ -1787,11 +1847,18 @@ async def delete_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1804,15 +1871,14 @@ async def delete_batch_prediction_job( # Done; return the response. return response - async def cancel_batch_prediction_job( - self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_batch_prediction_job(self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -1850,10 +1916,8 @@ async def cancel_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = job_service.CancelBatchPredictionJobRequest(request) @@ -1874,23 +1938,740 @@ async def cancel_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + async def create_model_deployment_monitoring_job(self, + request: job_service.CreateModelDeploymentMonitoringJobRequest = None, + *, + parent: str = None, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + r"""Creates a ModelDeploymentMonitoringJob. It will run + periodically on a configured interval. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateModelDeploymentMonitoringJobRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.CreateModelDeploymentMonitoringJob][]. + parent (:class:`str`): + Required. The parent of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + model_deployment_monitoring_job (:class:`google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob`): + Required. The + ModelDeploymentMonitoringJob to create + + This corresponds to the ``model_deployment_monitoring_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob: + Represents a job that runs + periodically to monitor the deployed + models in an endpoint. It will analyze + the logged training & prediction data to + detect any abnormal behaviors. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, model_deployment_monitoring_job]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CreateModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model_deployment_monitoring_job is not None: + request.model_deployment_monitoring_job = model_deployment_monitoring_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_model_deployment_monitoring_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def search_model_deployment_monitoring_stats_anomalies(self, + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, + *, + model_deployment_monitoring_job: str = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: + r"""Searches Model Monitoring Statistics generated within + a given time window. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + model_deployment_monitoring_job (:class:`str`): + Required. ModelDeploymentMonitoring Job resource name. + Format: + \`projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job} + + This corresponds to the ``model_deployment_monitoring_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model_id (:class:`str`): + Required. The DeployedModel ID of the + [google.cloud.aiplatform.master.ModelDeploymentMonitoringObjectiveConfig.deployed_model_id]. + + This corresponds to the ``deployed_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: + Response message for + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model_deployment_monitoring_job is not None: + request.model_deployment_monitoring_job = model_deployment_monitoring_job + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.search_model_deployment_monitoring_stats_anomalies, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model_deployment_monitoring_job', request.model_deployment_monitoring_job), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_model_deployment_monitoring_job(self, + request: job_service.GetModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + r"""Gets a ModelDeploymentMonitoringJob. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetModelDeploymentMonitoringJobRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.GetModelDeploymentMonitoringJob][]. + name (:class:`str`): + Required. The resource name of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob: + Represents a job that runs + periodically to monitor the deployed + models in an endpoint. It will analyze + the logged training & prediction data to + detect any abnormal behaviors. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.GetModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model_deployment_monitoring_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_model_deployment_monitoring_jobs(self, + request: job_service.ListModelDeploymentMonitoringJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelDeploymentMonitoringJobsAsyncPager: + r"""Lists ModelDeploymentMonitoringJobs in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + parent (:class:`str`): + Required. The parent of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListModelDeploymentMonitoringJobsAsyncPager: + Response message for + [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.ListModelDeploymentMonitoringJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_deployment_monitoring_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelDeploymentMonitoringJobsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_model_deployment_monitoring_job(self, + request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, + *, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a ModelDeploymentMonitoringJob. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateModelDeploymentMonitoringJobRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + model_deployment_monitoring_job (:class:`google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob`): + Required. The model monitoring + configuration which replaces the + resource on the server. + + This corresponds to the ``model_deployment_monitoring_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to + the resource. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob` Represents a job that runs periodically to monitor the deployed models in an + endpoint. It will analyze the logged training & + prediction data to detect any abnormal behaviors. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model_deployment_monitoring_job, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model_deployment_monitoring_job is not None: + request.model_deployment_monitoring_job = model_deployment_monitoring_job + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_model_deployment_monitoring_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model_deployment_monitoring_job.name', request.model_deployment_monitoring_job.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + metadata_type=job_service.UpdateModelDeploymentMonitoringJobOperationMetadata, + ) + + # Done; return the response. + return response + + async def delete_model_deployment_monitoring_job(self, + request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a ModelDeploymentMonitoringJob. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteModelDeploymentMonitoringJobRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.DeleteModelDeploymentMonitoringJob][]. + name (:class:`str`): + Required. The resource name of the model monitoring job + to delete. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_model_deployment_monitoring_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def pause_model_deployment_monitoring_job(self, + request: job_service.PauseModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Pauses a ModelDeploymentMonitoringJob. If the job is running, + the server makes a best effort to cancel the job. Will mark + ``ModelDeploymentMonitoringJob.state`` + to 'PAUSED'. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.PauseModelDeploymentMonitoringJobRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.PauseModelDeploymentMonitoringJob][]. + name (:class:`str`): + Required. The resource name of the + ModelDeploymentMonitoringJob to pause. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.PauseModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.pause_model_deployment_monitoring_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def resume_model_deployment_monitoring_job(self, + request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Resumes a paused ModelDeploymentMonitoringJob. It + will start to run from next scheduled time. A deleted + ModelDeploymentMonitoringJob can't be resumed. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ResumeModelDeploymentMonitoringJobRequest`): + The request object. Request message for + [ModelDeploymentMonitoringJobService.ResumeModelDeploymentMonitoringJob][]. + name (:class:`str`): + Required. The resource name of the + ModelDeploymentMonitoringJob to resume. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.resume_model_deployment_monitoring_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("JobServiceAsyncClient",) +__all__ = ( + 'JobServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index 81fa0d786f..cb4d402b6a 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -23,42 +23,42 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study +from google.protobuf import duration_pb2 as duration # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore @@ -76,12 +76,13 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry["grpc"] = JobServiceGrpcTransport - _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport + _transport_registry['grpc'] = JobServiceGrpcTransport + _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -132,7 +133,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -167,8 +168,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: JobServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -183,194 +185,165 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path( - project: str, location: str, batch_prediction_job: str, - ) -> str: + def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, - location=location, - batch_prediction_job=batch_prediction_job, - ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str, location: str, custom_job: str,) -> str: + def custom_job_path(project: str,location: str,custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str, str]: + def parse_custom_job_path(path: str) -> Dict[str,str]: """Parse a custom_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path( - project: str, location: str, data_labeling_job: str, - ) -> str: + def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str, str]: + def parse_data_labeling_job_path(path: str) -> Dict[str,str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def endpoint_path(project: str,location: str,endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str,str]: + """Parse a endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path( - project: str, location: str, hyperparameter_tuning_job: str, - ) -> str: + def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_deployment_monitoring_job_path(project: str,location: str,model_deployment_monitoring_job: str,) -> str: + """Return a fully-qualified model_deployment_monitoring_job string.""" + return "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format(project=project, location=location, model_deployment_monitoring_job=model_deployment_monitoring_job, ) + + @staticmethod + def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str,str]: + """Parse a model_deployment_monitoring_job path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/modelDeploymentMonitoringJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str, location: str, study: str, trial: str,) -> str: + def trial_path(project: str,location: str,study: str,trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( - project=project, location=location, study=study, trial=trial, - ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) @staticmethod - def parse_trial_path(path: str) -> Dict[str, str]: + def parse_trial_path(path: str) -> Dict[str,str]: """Parse a trial path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -414,9 +387,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -426,9 +397,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -440,9 +409,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -454,10 +421,8 @@ def __init__( if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -476,16 +441,15 @@ def __init__( client_info=client_info, ) - def create_custom_job( - self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job(self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -530,10 +494,8 @@ def create_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -557,24 +519,30 @@ def create_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_custom_job( - self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job(self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -612,10 +580,8 @@ def get_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -637,24 +603,30 @@ def get_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_custom_jobs( - self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs(self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -690,10 +662,8 @@ def list_custom_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -715,30 +685,39 @@ def list_custom_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_custom_job( - self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_custom_job(self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a CustomJob. Args: @@ -784,10 +763,8 @@ def delete_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -809,11 +786,18 @@ def delete_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -826,15 +810,14 @@ def delete_custom_job( # Done; return the response. return response - def cancel_custom_job( - self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job(self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -872,10 +855,8 @@ def cancel_custom_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -897,24 +878,28 @@ def cancel_custom_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - def create_data_labeling_job( - self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def create_data_labeling_job(self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -954,10 +939,8 @@ def create_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -981,24 +964,30 @@ def create_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_data_labeling_job( - self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job(self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -1031,10 +1020,8 @@ def get_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -1056,24 +1043,30 @@ def get_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_data_labeling_jobs( - self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs(self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1108,10 +1101,8 @@ def list_data_labeling_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1133,30 +1124,39 @@ def list_data_labeling_jobs( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job( - self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_data_labeling_job(self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1202,10 +1202,8 @@ def delete_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1227,11 +1225,18 @@ def delete_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1244,15 +1249,14 @@ def delete_data_labeling_job( # Done; return the response. return response - def cancel_data_labeling_job( - self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job(self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1279,10 +1283,8 @@ def cancel_data_labeling_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1304,24 +1306,28 @@ def cancel_data_labeling_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - def create_hyperparameter_tuning_job( - self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def create_hyperparameter_tuning_job(self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1363,10 +1369,8 @@ def create_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1385,31 +1389,35 @@ def create_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.create_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_hyperparameter_tuning_job( - self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job(self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1444,10 +1452,8 @@ def get_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1464,31 +1470,35 @@ def get_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.get_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_hyperparameter_tuning_jobs( - self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs(self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1524,10 +1534,8 @@ def list_hyperparameter_tuning_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1544,37 +1552,44 @@ def list_hyperparameter_tuning_jobs( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_hyperparameter_tuning_jobs - ] + rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job( - self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_hyperparameter_tuning_job(self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1620,10 +1635,8 @@ def delete_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1640,18 +1653,23 @@ def delete_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.delete_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1664,15 +1682,14 @@ def delete_hyperparameter_tuning_job( # Done; return the response. return response - def cancel_hyperparameter_tuning_job( - self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job(self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1712,10 +1729,8 @@ def cancel_hyperparameter_tuning_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1732,31 +1747,33 @@ def cancel_hyperparameter_tuning_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.cancel_hyperparameter_tuning_job - ] + rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - def create_batch_prediction_job( - self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def create_batch_prediction_job(self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1801,10 +1818,8 @@ def create_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1823,31 +1838,35 @@ def create_batch_prediction_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.create_batch_prediction_job - ] + rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_batch_prediction_job( - self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job(self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1884,10 +1903,8 @@ def get_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1909,24 +1926,30 @@ def get_batch_prediction_job( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_batch_prediction_jobs( - self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs(self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1962,10 +1985,8 @@ def list_batch_prediction_jobs( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -1982,37 +2003,44 @@ def list_batch_prediction_jobs( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_batch_prediction_jobs - ] + rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job( - self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_batch_prediction_job(self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2059,10 +2087,8 @@ def delete_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2079,18 +2105,23 @@ def delete_batch_prediction_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.delete_batch_prediction_job - ] + rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2103,15 +2134,14 @@ def delete_batch_prediction_job( # Done; return the response. return response - def cancel_batch_prediction_job( - self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job(self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -2149,10 +2179,8 @@ def cancel_batch_prediction_job( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2169,30 +2197,753 @@ def cancel_batch_prediction_job( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.cancel_batch_prediction_job - ] + rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def create_model_deployment_monitoring_job(self, + request: job_service.CreateModelDeploymentMonitoringJobRequest = None, + *, + parent: str = None, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + r"""Creates a ModelDeploymentMonitoringJob. It will run + periodically on a configured interval. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateModelDeploymentMonitoringJobRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.CreateModelDeploymentMonitoringJob][]. + parent (str): + Required. The parent of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + model_deployment_monitoring_job (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob): + Required. The + ModelDeploymentMonitoringJob to create + + This corresponds to the ``model_deployment_monitoring_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob: + Represents a job that runs + periodically to monitor the deployed + models in an endpoint. It will analyze + the logged training & prediction data to + detect any abnormal behaviors. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, model_deployment_monitoring_job]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateModelDeploymentMonitoringJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateModelDeploymentMonitoringJobRequest): + request = job_service.CreateModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model_deployment_monitoring_job is not None: + request.model_deployment_monitoring_job = model_deployment_monitoring_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_model_deployment_monitoring_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def search_model_deployment_monitoring_stats_anomalies(self, + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, + *, + model_deployment_monitoring_job: str = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: + r"""Searches Model Monitoring Statistics generated within + a given time window. + + Args: + request (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + model_deployment_monitoring_job (str): + Required. ModelDeploymentMonitoring Job resource name. + Format: + \`projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job} + + This corresponds to the ``model_deployment_monitoring_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model_id (str): + Required. The DeployedModel ID of the + [google.cloud.aiplatform.master.ModelDeploymentMonitoringObjectiveConfig.deployed_model_id]. + + This corresponds to the ``deployed_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: + Response message for + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model_deployment_monitoring_job is not None: + request.model_deployment_monitoring_job = model_deployment_monitoring_job + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.search_model_deployment_monitoring_stats_anomalies] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model_deployment_monitoring_job', request.model_deployment_monitoring_job), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_model_deployment_monitoring_job(self, + request: job_service.GetModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + r"""Gets a ModelDeploymentMonitoringJob. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetModelDeploymentMonitoringJobRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.GetModelDeploymentMonitoringJob][]. + name (str): + Required. The resource name of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob: + Represents a job that runs + periodically to monitor the deployed + models in an endpoint. It will analyze + the logged training & prediction data to + detect any abnormal behaviors. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetModelDeploymentMonitoringJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetModelDeploymentMonitoringJobRequest): + request = job_service.GetModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_model_deployment_monitoring_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_model_deployment_monitoring_jobs(self, + request: job_service.ListModelDeploymentMonitoringJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelDeploymentMonitoringJobsPager: + r"""Lists ModelDeploymentMonitoringJobs in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + parent (str): + Required. The parent of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListModelDeploymentMonitoringJobsPager: + Response message for + [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListModelDeploymentMonitoringJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListModelDeploymentMonitoringJobsRequest): + request = job_service.ListModelDeploymentMonitoringJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_model_deployment_monitoring_jobs] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListModelDeploymentMonitoringJobsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_model_deployment_monitoring_job(self, + request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, + *, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Updates a ModelDeploymentMonitoringJob. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateModelDeploymentMonitoringJobRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + model_deployment_monitoring_job (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob): + Required. The model monitoring + configuration which replaces the + resource on the server. + + This corresponds to the ``model_deployment_monitoring_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to + the resource. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob` Represents a job that runs periodically to monitor the deployed models in an + endpoint. It will analyze the logged training & + prediction data to detect any abnormal behaviors. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model_deployment_monitoring_job, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.UpdateModelDeploymentMonitoringJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.UpdateModelDeploymentMonitoringJobRequest): + request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model_deployment_monitoring_job is not None: + request.model_deployment_monitoring_job = model_deployment_monitoring_job + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_model_deployment_monitoring_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model_deployment_monitoring_job.name', request.model_deployment_monitoring_job.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + metadata_type=job_service.UpdateModelDeploymentMonitoringJobOperationMetadata, + ) + + # Done; return the response. + return response + + def delete_model_deployment_monitoring_job(self, + request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a ModelDeploymentMonitoringJob. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteModelDeploymentMonitoringJobRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.DeleteModelDeploymentMonitoringJob][]. + name (str): + Required. The resource name of the model monitoring job + to delete. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteModelDeploymentMonitoringJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteModelDeploymentMonitoringJobRequest): + request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_model_deployment_monitoring_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def pause_model_deployment_monitoring_job(self, + request: job_service.PauseModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Pauses a ModelDeploymentMonitoringJob. If the job is running, + the server makes a best effort to cancel the job. Will mark + ``ModelDeploymentMonitoringJob.state`` + to 'PAUSED'. + + Args: + request (google.cloud.aiplatform_v1beta1.types.PauseModelDeploymentMonitoringJobRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.PauseModelDeploymentMonitoringJob][]. + name (str): + Required. The resource name of the + ModelDeploymentMonitoringJob to pause. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.PauseModelDeploymentMonitoringJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.PauseModelDeploymentMonitoringJobRequest): + request = job_service.PauseModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.pause_model_deployment_monitoring_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def resume_model_deployment_monitoring_job(self, + request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Resumes a paused ModelDeploymentMonitoringJob. It + will start to run from next scheduled time. A deleted + ModelDeploymentMonitoringJob can't be resumed. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ResumeModelDeploymentMonitoringJobRequest): + The request object. Request message for + [ModelDeploymentMonitoringJobService.ResumeModelDeploymentMonitoringJob][]. + name (str): + Required. The resource name of the + ModelDeploymentMonitoringJob to resume. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ResumeModelDeploymentMonitoringJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ResumeModelDeploymentMonitoringJobRequest): + request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.resume_model_deployment_monitoring_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("JobServiceClient",) +__all__ = ( + 'JobServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 6c3da33d0a..85cb433f67 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -15,22 +15,15 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job class ListCustomJobsPager: @@ -50,15 +43,12 @@ class ListCustomJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListCustomJobsResponse], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -92,7 +82,7 @@ def __iter__(self) -> Iterable[custom_job.CustomJob]: yield from page.custom_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListCustomJobsAsyncPager: @@ -112,15 +102,12 @@ class ListCustomJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -158,7 +145,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataLabelingJobsPager: @@ -178,15 +165,12 @@ class ListDataLabelingJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListDataLabelingJobsResponse], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -220,7 +204,7 @@ def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: yield from page.data_labeling_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataLabelingJobsAsyncPager: @@ -240,15 +224,12 @@ class ListDataLabelingJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -286,7 +267,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsPager: @@ -306,15 +287,12 @@ class ListHyperparameterTuningJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -348,7 +326,7 @@ def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob yield from page.hyperparameter_tuning_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsAsyncPager: @@ -368,17 +346,12 @@ class ListHyperparameterTuningJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListHyperparameterTuningJobsResponse]], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -400,18 +373,14 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + async def pages(self) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__( - self, - ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + def __aiter__(self) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: async def async_generator(): async for page in self.pages: for response in page.hyperparameter_tuning_jobs: @@ -420,7 +389,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListBatchPredictionJobsPager: @@ -440,15 +409,12 @@ class ListBatchPredictionJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., job_service.ListBatchPredictionJobsResponse], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -482,7 +448,7 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: yield from page.batch_prediction_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListBatchPredictionJobsAsyncPager: @@ -502,15 +468,12 @@ class ListBatchPredictionJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -548,4 +511,248 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class SearchModelDeploymentMonitoringStatsAnomaliesPager: + """A pager for iterating through ``search_model_deployment_monitoring_stats_anomalies`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``monitoring_stats`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``SearchModelDeploymentMonitoringStatsAnomalies`` requests and continue to iterate + through the ``monitoring_stats`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse], + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, + response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies]: + for page in self.pages: + yield from page.monitoring_stats + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: + """A pager for iterating through ``search_model_deployment_monitoring_stats_anomalies`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``monitoring_stats`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``SearchModelDeploymentMonitoringStatsAnomalies`` requests and continue to iterate + through the ``monitoring_stats`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]], + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, + response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies]: + async def async_generator(): + async for page in self.pages: + for response in page.monitoring_stats: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelDeploymentMonitoringJobsPager: + """A pager for iterating through ``list_model_deployment_monitoring_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``model_deployment_monitoring_jobs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListModelDeploymentMonitoringJobs`` requests and continue to iterate + through the ``model_deployment_monitoring_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., job_service.ListModelDeploymentMonitoringJobsResponse], + request: job_service.ListModelDeploymentMonitoringJobsRequest, + response: job_service.ListModelDeploymentMonitoringJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListModelDeploymentMonitoringJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[job_service.ListModelDeploymentMonitoringJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + for page in self.pages: + yield from page.model_deployment_monitoring_jobs + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelDeploymentMonitoringJobsAsyncPager: + """A pager for iterating through ``list_model_deployment_monitoring_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_deployment_monitoring_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelDeploymentMonitoringJobs`` requests and continue to iterate + through the ``model_deployment_monitoring_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse]], + request: job_service.ListModelDeploymentMonitoringJobsRequest, + response: job_service.ListModelDeploymentMonitoringJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListModelDeploymentMonitoringJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListModelDeploymentMonitoringJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + async def async_generator(): + async for page in self.pages: + for response in page.model_deployment_monitoring_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index 349bfbcdea..8b5de46a7e 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] -_transport_registry["grpc"] = JobServiceGrpcTransport -_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = JobServiceGrpcTransport +_transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport __all__ = ( - "JobServiceTransport", - "JobServiceGrpcTransport", - "JobServiceGrpcAsyncIOTransport", + 'JobServiceTransport', + 'JobServiceGrpcTransport', + 'JobServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 3d1f0be59b..8ec1ad88c2 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -21,26 +21,22 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -48,29 +44,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -86,57 +82,65 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, default_timeout=5.0, client_info=client_info, + self.create_custom_job, + default_timeout=5.0, + client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, default_timeout=5.0, client_info=client_info, + self.get_custom_job, + default_timeout=5.0, + client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, default_timeout=5.0, client_info=client_info, + self.list_custom_jobs, + default_timeout=5.0, + client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, default_timeout=5.0, client_info=client_info, + self.delete_custom_job, + default_timeout=5.0, + client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, default_timeout=5.0, client_info=client_info, + self.cancel_custom_job, + default_timeout=5.0, + client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, @@ -213,6 +217,47 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), + self.create_model_deployment_monitoring_job: gapic_v1.method.wrap_method( + self.create_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.search_model_deployment_monitoring_stats_anomalies: gapic_v1.method.wrap_method( + self.search_model_deployment_monitoring_stats_anomalies, + default_timeout=None, + client_info=client_info, + ), + self.get_model_deployment_monitoring_job: gapic_v1.method.wrap_method( + self.get_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.list_model_deployment_monitoring_jobs: gapic_v1.method.wrap_method( + self.list_model_deployment_monitoring_jobs, + default_timeout=None, + client_info=client_info, + ), + self.update_model_deployment_monitoring_job: gapic_v1.method.wrap_method( + self.update_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.delete_model_deployment_monitoring_job: gapic_v1.method.wrap_method( + self.delete_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.pause_model_deployment_monitoring_job: gapic_v1.method.wrap_method( + self.pause_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.resume_model_deployment_monitoring_job: gapic_v1.method.wrap_method( + self.resume_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + } @property @@ -221,216 +266,258 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_custom_job( - self, - ) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] - ], - ]: + def create_custom_job(self) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, + typing.Awaitable[gca_custom_job.CustomJob] + ]]: raise NotImplementedError() @property - def get_custom_job( - self, - ) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], - ]: + def get_custom_job(self) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[ + custom_job.CustomJob, + typing.Awaitable[custom_job.CustomJob] + ]]: raise NotImplementedError() @property - def list_custom_jobs( - self, - ) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse], - ], - ]: + def list_custom_jobs(self) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse] + ]]: raise NotImplementedError() @property - def delete_custom_job( - self, - ) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_custom_job(self) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_custom_job( - self, - ) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_custom_job(self) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def create_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob], - ], - ]: + def create_data_labeling_job(self) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob] + ]]: raise NotImplementedError() @property - def get_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob], - ], - ]: + def get_data_labeling_job(self) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob] + ]]: raise NotImplementedError() @property - def list_data_labeling_jobs( - self, - ) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse], - ], - ]: + def list_data_labeling_jobs(self) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse] + ]]: raise NotImplementedError() @property - def delete_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_data_labeling_job(self) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_data_labeling_job(self) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def create_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], - ], - ]: + def create_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] + ]]: raise NotImplementedError() @property - def get_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], - ], - ]: + def get_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] + ]]: raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs( - self, - ) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], - ], - ]: + def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ]]: raise NotImplementedError() @property - def delete_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def create_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], - ], - ]: + def create_batch_prediction_job(self) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] + ]]: raise NotImplementedError() @property - def get_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - typing.Union[ - batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob], - ], - ]: + def get_batch_prediction_job(self) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob] + ]]: raise NotImplementedError() @property - def list_batch_prediction_jobs( - self, - ) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse], - ], - ]: + def list_batch_prediction_jobs(self) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse] + ]]: raise NotImplementedError() @property - def delete_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_batch_prediction_job(self) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_batch_prediction_job(self) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() + @property + def create_model_deployment_monitoring_job(self) -> typing.Callable[ + [job_service.CreateModelDeploymentMonitoringJobRequest], + typing.Union[ + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + typing.Awaitable[gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob] + ]]: + raise NotImplementedError() + + @property + def search_model_deployment_monitoring_stats_anomalies(self) -> typing.Callable[ + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + typing.Union[ + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + typing.Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse] + ]]: + raise NotImplementedError() + + @property + def get_model_deployment_monitoring_job(self) -> typing.Callable[ + [job_service.GetModelDeploymentMonitoringJobRequest], + typing.Union[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + typing.Awaitable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob] + ]]: + raise NotImplementedError() + + @property + def list_model_deployment_monitoring_jobs(self) -> typing.Callable[ + [job_service.ListModelDeploymentMonitoringJobsRequest], + typing.Union[ + job_service.ListModelDeploymentMonitoringJobsResponse, + typing.Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse] + ]]: + raise NotImplementedError() -__all__ = ("JobServiceTransport",) + @property + def update_model_deployment_monitoring_job(self) -> typing.Callable[ + [job_service.UpdateModelDeploymentMonitoringJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def delete_model_deployment_monitoring_job(self) -> typing.Callable[ + [job_service.DeleteModelDeploymentMonitoringJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def pause_model_deployment_monitoring_job(self) -> typing.Callable[ + [job_service.PauseModelDeploymentMonitoringJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() + + @property + def resume_model_deployment_monitoring_job(self) -> typing.Callable[ + [job_service.ResumeModelDeploymentMonitoringJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'JobServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 763f510e5b..61b67d0f98 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -18,30 +18,26 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -60,24 +56,21 @@ class JobServiceGrpcTransport(JobServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -123,7 +116,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -131,70 +127,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -202,32 +178,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -257,12 +221,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -274,15 +239,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_custom_job( - self, - ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: + def create_custom_job(self) -> Callable[ + [job_service.CreateCustomJobRequest], + gca_custom_job.CustomJob]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -298,18 +265,18 @@ def create_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_custom_job" not in self._stubs: - self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", + if 'create_custom_job' not in self._stubs: + self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs["create_custom_job"] + return self._stubs['create_custom_job'] @property - def get_custom_job( - self, - ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: + def get_custom_job(self) -> Callable[ + [job_service.GetCustomJobRequest], + custom_job.CustomJob]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -324,20 +291,18 @@ def get_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_custom_job" not in self._stubs: - self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", + if 'get_custom_job' not in self._stubs: + self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs["get_custom_job"] + return self._stubs['get_custom_job'] @property - def list_custom_jobs( - self, - ) -> Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse - ]: + def list_custom_jobs(self) -> Callable[ + [job_service.ListCustomJobsRequest], + job_service.ListCustomJobsResponse]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -352,18 +317,18 @@ def list_custom_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_custom_jobs" not in self._stubs: - self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", + if 'list_custom_jobs' not in self._stubs: + self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs["list_custom_jobs"] + return self._stubs['list_custom_jobs'] @property - def delete_custom_job( - self, - ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: + def delete_custom_job(self) -> Callable[ + [job_service.DeleteCustomJobRequest], + operations.Operation]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -378,18 +343,18 @@ def delete_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_custom_job" not in self._stubs: - self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", + if 'delete_custom_job' not in self._stubs: + self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_custom_job"] + return self._stubs['delete_custom_job'] @property - def cancel_custom_job( - self, - ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: + def cancel_custom_job(self) -> Callable[ + [job_service.CancelCustomJobRequest], + empty.Empty]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -416,21 +381,18 @@ def cancel_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_custom_job" not in self._stubs: - self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", + if 'cancel_custom_job' not in self._stubs: + self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_custom_job"] + return self._stubs['cancel_custom_job'] @property - def create_data_labeling_job( - self, - ) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob, - ]: + def create_data_labeling_job(self) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -445,20 +407,18 @@ def create_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_data_labeling_job" not in self._stubs: - self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", + if 'create_data_labeling_job' not in self._stubs: + self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["create_data_labeling_job"] + return self._stubs['create_data_labeling_job'] @property - def get_data_labeling_job( - self, - ) -> Callable[ - [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob - ]: + def get_data_labeling_job(self) -> Callable[ + [job_service.GetDataLabelingJobRequest], + data_labeling_job.DataLabelingJob]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -473,21 +433,18 @@ def get_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_data_labeling_job" not in self._stubs: - self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", + if 'get_data_labeling_job' not in self._stubs: + self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["get_data_labeling_job"] + return self._stubs['get_data_labeling_job'] @property - def list_data_labeling_jobs( - self, - ) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, - ]: + def list_data_labeling_jobs(self) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -502,18 +459,18 @@ def list_data_labeling_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_labeling_jobs" not in self._stubs: - self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", + if 'list_data_labeling_jobs' not in self._stubs: + self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs["list_data_labeling_jobs"] + return self._stubs['list_data_labeling_jobs'] @property - def delete_data_labeling_job( - self, - ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: + def delete_data_labeling_job(self) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], + operations.Operation]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -528,18 +485,18 @@ def delete_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_data_labeling_job" not in self._stubs: - self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", + if 'delete_data_labeling_job' not in self._stubs: + self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_data_labeling_job"] + return self._stubs['delete_data_labeling_job'] @property - def cancel_data_labeling_job( - self, - ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: + def cancel_data_labeling_job(self) -> Callable[ + [job_service.CancelDataLabelingJobRequest], + empty.Empty]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -555,21 +512,18 @@ def cancel_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_data_labeling_job" not in self._stubs: - self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", + if 'cancel_data_labeling_job' not in self._stubs: + self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_data_labeling_job"] + return self._stubs['cancel_data_labeling_job'] @property - def create_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - ]: + def create_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -585,23 +539,18 @@ def create_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "create_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", + if 'create_hyperparameter_tuning_job' not in self._stubs: + self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["create_hyperparameter_tuning_job"] + return self._stubs['create_hyperparameter_tuning_job'] @property - def get_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob, - ]: + def get_hyperparameter_tuning_job(self) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -616,23 +565,18 @@ def get_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "get_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", + if 'get_hyperparameter_tuning_job' not in self._stubs: + self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["get_hyperparameter_tuning_job"] + return self._stubs['get_hyperparameter_tuning_job'] @property - def list_hyperparameter_tuning_jobs( - self, - ) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, - ]: + def list_hyperparameter_tuning_jobs(self) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -648,22 +592,18 @@ def list_hyperparameter_tuning_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_hyperparameter_tuning_jobs" not in self._stubs: - self._stubs[ - "list_hyperparameter_tuning_jobs" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", + if 'list_hyperparameter_tuning_jobs' not in self._stubs: + self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs["list_hyperparameter_tuning_jobs"] + return self._stubs['list_hyperparameter_tuning_jobs'] @property - def delete_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation - ]: + def delete_hyperparameter_tuning_job(self) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + operations.Operation]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -679,20 +619,18 @@ def delete_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "delete_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", + if 'delete_hyperparameter_tuning_job' not in self._stubs: + self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_hyperparameter_tuning_job"] + return self._stubs['delete_hyperparameter_tuning_job'] @property - def cancel_hyperparameter_tuning_job( - self, - ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]: + def cancel_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + empty.Empty]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -721,23 +659,18 @@ def cancel_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "cancel_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", + if 'cancel_hyperparameter_tuning_job' not in self._stubs: + self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_hyperparameter_tuning_job"] + return self._stubs['cancel_hyperparameter_tuning_job'] @property - def create_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob, - ]: + def create_batch_prediction_job(self) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + gca_batch_prediction_job.BatchPredictionJob]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -753,21 +686,18 @@ def create_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_batch_prediction_job" not in self._stubs: - self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", + if 'create_batch_prediction_job' not in self._stubs: + self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["create_batch_prediction_job"] + return self._stubs['create_batch_prediction_job'] @property - def get_batch_prediction_job( - self, - ) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob, - ]: + def get_batch_prediction_job(self) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -782,21 +712,18 @@ def get_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_batch_prediction_job" not in self._stubs: - self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", + if 'get_batch_prediction_job' not in self._stubs: + self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["get_batch_prediction_job"] + return self._stubs['get_batch_prediction_job'] @property - def list_batch_prediction_jobs( - self, - ) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, - ]: + def list_batch_prediction_jobs(self) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -811,18 +738,18 @@ def list_batch_prediction_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_batch_prediction_jobs" not in self._stubs: - self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", + if 'list_batch_prediction_jobs' not in self._stubs: + self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs["list_batch_prediction_jobs"] + return self._stubs['list_batch_prediction_jobs'] @property - def delete_batch_prediction_job( - self, - ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: + def delete_batch_prediction_job(self) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], + operations.Operation]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -838,18 +765,18 @@ def delete_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_batch_prediction_job" not in self._stubs: - self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", + if 'delete_batch_prediction_job' not in self._stubs: + self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_batch_prediction_job"] + return self._stubs['delete_batch_prediction_job'] @property - def cancel_batch_prediction_job( - self, - ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: + def cancel_batch_prediction_job(self) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], + empty.Empty]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -875,13 +802,238 @@ def cancel_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_batch_prediction_job" not in self._stubs: - self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", + if 'cancel_batch_prediction_job' not in self._stubs: + self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_batch_prediction_job"] + return self._stubs['cancel_batch_prediction_job'] + + @property + def create_model_deployment_monitoring_job(self) -> Callable[ + [job_service.CreateModelDeploymentMonitoringJobRequest], + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + r"""Return a callable for the create model deployment + monitoring job method over gRPC. + + Creates a ModelDeploymentMonitoringJob. It will run + periodically on a configured interval. + + Returns: + Callable[[~.CreateModelDeploymentMonitoringJobRequest], + ~.ModelDeploymentMonitoringJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_model_deployment_monitoring_job' not in self._stubs: + self._stubs['create_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateModelDeploymentMonitoringJob', + request_serializer=job_service.CreateModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, + ) + return self._stubs['create_model_deployment_monitoring_job'] + + @property + def search_model_deployment_monitoring_stats_anomalies(self) -> Callable[ + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + r"""Return a callable for the search model deployment + monitoring stats anomalies method over gRPC. + + Searches Model Monitoring Statistics generated within + a given time window. + + Returns: + Callable[[~.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + ~.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'search_model_deployment_monitoring_stats_anomalies' not in self._stubs: + self._stubs['search_model_deployment_monitoring_stats_anomalies'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/SearchModelDeploymentMonitoringStatsAnomalies', + request_serializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest.serialize, + response_deserializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse.deserialize, + ) + return self._stubs['search_model_deployment_monitoring_stats_anomalies'] + + @property + def get_model_deployment_monitoring_job(self) -> Callable[ + [job_service.GetModelDeploymentMonitoringJobRequest], + model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + r"""Return a callable for the get model deployment + monitoring job method over gRPC. + + Gets a ModelDeploymentMonitoringJob. + + Returns: + Callable[[~.GetModelDeploymentMonitoringJobRequest], + ~.ModelDeploymentMonitoringJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_model_deployment_monitoring_job' not in self._stubs: + self._stubs['get_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetModelDeploymentMonitoringJob', + request_serializer=job_service.GetModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, + ) + return self._stubs['get_model_deployment_monitoring_job'] + + @property + def list_model_deployment_monitoring_jobs(self) -> Callable[ + [job_service.ListModelDeploymentMonitoringJobsRequest], + job_service.ListModelDeploymentMonitoringJobsResponse]: + r"""Return a callable for the list model deployment + monitoring jobs method over gRPC. + + Lists ModelDeploymentMonitoringJobs in a Location. + + Returns: + Callable[[~.ListModelDeploymentMonitoringJobsRequest], + ~.ListModelDeploymentMonitoringJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_model_deployment_monitoring_jobs' not in self._stubs: + self._stubs['list_model_deployment_monitoring_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListModelDeploymentMonitoringJobs', + request_serializer=job_service.ListModelDeploymentMonitoringJobsRequest.serialize, + response_deserializer=job_service.ListModelDeploymentMonitoringJobsResponse.deserialize, + ) + return self._stubs['list_model_deployment_monitoring_jobs'] + + @property + def update_model_deployment_monitoring_job(self) -> Callable[ + [job_service.UpdateModelDeploymentMonitoringJobRequest], + operations.Operation]: + r"""Return a callable for the update model deployment + monitoring job method over gRPC. + + Updates a ModelDeploymentMonitoringJob. + + Returns: + Callable[[~.UpdateModelDeploymentMonitoringJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_model_deployment_monitoring_job' not in self._stubs: + self._stubs['update_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/UpdateModelDeploymentMonitoringJob', + request_serializer=job_service.UpdateModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_model_deployment_monitoring_job'] + + @property + def delete_model_deployment_monitoring_job(self) -> Callable[ + [job_service.DeleteModelDeploymentMonitoringJobRequest], + operations.Operation]: + r"""Return a callable for the delete model deployment + monitoring job method over gRPC. + + Deletes a ModelDeploymentMonitoringJob. + + Returns: + Callable[[~.DeleteModelDeploymentMonitoringJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_model_deployment_monitoring_job' not in self._stubs: + self._stubs['delete_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteModelDeploymentMonitoringJob', + request_serializer=job_service.DeleteModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_model_deployment_monitoring_job'] + + @property + def pause_model_deployment_monitoring_job(self) -> Callable[ + [job_service.PauseModelDeploymentMonitoringJobRequest], + empty.Empty]: + r"""Return a callable for the pause model deployment + monitoring job method over gRPC. + + Pauses a ModelDeploymentMonitoringJob. If the job is running, + the server makes a best effort to cancel the job. Will mark + ``ModelDeploymentMonitoringJob.state`` + to 'PAUSED'. + + Returns: + Callable[[~.PauseModelDeploymentMonitoringJobRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'pause_model_deployment_monitoring_job' not in self._stubs: + self._stubs['pause_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/PauseModelDeploymentMonitoringJob', + request_serializer=job_service.PauseModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['pause_model_deployment_monitoring_job'] + + @property + def resume_model_deployment_monitoring_job(self) -> Callable[ + [job_service.ResumeModelDeploymentMonitoringJobRequest], + empty.Empty]: + r"""Return a callable for the resume model deployment + monitoring job method over gRPC. + + Resumes a paused ModelDeploymentMonitoringJob. It + will start to run from next scheduled time. A deleted + ModelDeploymentMonitoringJob can't be resumed. + + Returns: + Callable[[~.ResumeModelDeploymentMonitoringJobRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'resume_model_deployment_monitoring_job' not in self._stubs: + self._stubs['resume_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ResumeModelDeploymentMonitoringJob', + request_serializer=job_service.ResumeModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['resume_model_deployment_monitoring_job'] -__all__ = ("JobServiceGrpcTransport",) +__all__ = ( + 'JobServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index 07655ba262..3cd0904008 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -18,31 +18,27 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -67,18 +63,16 @@ class JobServiceGrpcAsyncIOTransport(JobServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -104,24 +98,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -156,10 +148,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -168,7 +160,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -176,70 +171,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -247,18 +222,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -287,11 +252,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_custom_job( - self, - ) -> Callable[ - [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] - ]: + def create_custom_job(self) -> Callable[ + [job_service.CreateCustomJobRequest], + Awaitable[gca_custom_job.CustomJob]]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -307,18 +270,18 @@ def create_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_custom_job" not in self._stubs: - self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", + if 'create_custom_job' not in self._stubs: + self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs["create_custom_job"] + return self._stubs['create_custom_job'] @property - def get_custom_job( - self, - ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: + def get_custom_job(self) -> Callable[ + [job_service.GetCustomJobRequest], + Awaitable[custom_job.CustomJob]]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -333,21 +296,18 @@ def get_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_custom_job" not in self._stubs: - self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", + if 'get_custom_job' not in self._stubs: + self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs["get_custom_job"] + return self._stubs['get_custom_job'] @property - def list_custom_jobs( - self, - ) -> Callable[ - [job_service.ListCustomJobsRequest], - Awaitable[job_service.ListCustomJobsResponse], - ]: + def list_custom_jobs(self) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse]]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -362,20 +322,18 @@ def list_custom_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_custom_jobs" not in self._stubs: - self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", + if 'list_custom_jobs' not in self._stubs: + self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs["list_custom_jobs"] + return self._stubs['list_custom_jobs'] @property - def delete_custom_job( - self, - ) -> Callable[ - [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] - ]: + def delete_custom_job(self) -> Callable[ + [job_service.DeleteCustomJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -390,18 +348,18 @@ def delete_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_custom_job" not in self._stubs: - self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", + if 'delete_custom_job' not in self._stubs: + self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_custom_job"] + return self._stubs['delete_custom_job'] @property - def cancel_custom_job( - self, - ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: + def cancel_custom_job(self) -> Callable[ + [job_service.CancelCustomJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -428,21 +386,18 @@ def cancel_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_custom_job" not in self._stubs: - self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", + if 'cancel_custom_job' not in self._stubs: + self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_custom_job"] + return self._stubs['cancel_custom_job'] @property - def create_data_labeling_job( - self, - ) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - Awaitable[gca_data_labeling_job.DataLabelingJob], - ]: + def create_data_labeling_job(self) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob]]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -457,21 +412,18 @@ def create_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_data_labeling_job" not in self._stubs: - self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", + if 'create_data_labeling_job' not in self._stubs: + self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["create_data_labeling_job"] + return self._stubs['create_data_labeling_job'] @property - def get_data_labeling_job( - self, - ) -> Callable[ - [job_service.GetDataLabelingJobRequest], - Awaitable[data_labeling_job.DataLabelingJob], - ]: + def get_data_labeling_job(self) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob]]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -486,21 +438,18 @@ def get_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_data_labeling_job" not in self._stubs: - self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", + if 'get_data_labeling_job' not in self._stubs: + self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["get_data_labeling_job"] + return self._stubs['get_data_labeling_job'] @property - def list_data_labeling_jobs( - self, - ) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - Awaitable[job_service.ListDataLabelingJobsResponse], - ]: + def list_data_labeling_jobs(self) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse]]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -515,20 +464,18 @@ def list_data_labeling_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_labeling_jobs" not in self._stubs: - self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", + if 'list_data_labeling_jobs' not in self._stubs: + self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs["list_data_labeling_jobs"] + return self._stubs['list_data_labeling_jobs'] @property - def delete_data_labeling_job( - self, - ) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] - ]: + def delete_data_labeling_job(self) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -543,18 +490,18 @@ def delete_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_data_labeling_job" not in self._stubs: - self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", + if 'delete_data_labeling_job' not in self._stubs: + self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_data_labeling_job"] + return self._stubs['delete_data_labeling_job'] @property - def cancel_data_labeling_job( - self, - ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: + def cancel_data_labeling_job(self) -> Callable[ + [job_service.CancelDataLabelingJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -570,21 +517,18 @@ def cancel_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_data_labeling_job" not in self._stubs: - self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", + if 'cancel_data_labeling_job' not in self._stubs: + self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_data_labeling_job"] + return self._stubs['cancel_data_labeling_job'] @property - def create_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], - ]: + def create_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob]]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -600,23 +544,18 @@ def create_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "create_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", + if 'create_hyperparameter_tuning_job' not in self._stubs: + self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["create_hyperparameter_tuning_job"] + return self._stubs['create_hyperparameter_tuning_job'] @property - def get_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], - ]: + def get_hyperparameter_tuning_job(self) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob]]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -631,23 +570,18 @@ def get_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "get_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", + if 'get_hyperparameter_tuning_job' not in self._stubs: + self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["get_hyperparameter_tuning_job"] + return self._stubs['get_hyperparameter_tuning_job'] @property - def list_hyperparameter_tuning_jobs( - self, - ) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - Awaitable[job_service.ListHyperparameterTuningJobsResponse], - ]: + def list_hyperparameter_tuning_jobs(self) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse]]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -663,23 +597,18 @@ def list_hyperparameter_tuning_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_hyperparameter_tuning_jobs" not in self._stubs: - self._stubs[ - "list_hyperparameter_tuning_jobs" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", + if 'list_hyperparameter_tuning_jobs' not in self._stubs: + self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs["list_hyperparameter_tuning_jobs"] + return self._stubs['list_hyperparameter_tuning_jobs'] @property - def delete_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - Awaitable[operations.Operation], - ]: + def delete_hyperparameter_tuning_job(self) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -695,22 +624,18 @@ def delete_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "delete_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", + if 'delete_hyperparameter_tuning_job' not in self._stubs: + self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_hyperparameter_tuning_job"] + return self._stubs['delete_hyperparameter_tuning_job'] @property - def cancel_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] - ]: + def cancel_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -739,23 +664,18 @@ def cancel_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "cancel_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", + if 'cancel_hyperparameter_tuning_job' not in self._stubs: + self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_hyperparameter_tuning_job"] + return self._stubs['cancel_hyperparameter_tuning_job'] @property - def create_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - Awaitable[gca_batch_prediction_job.BatchPredictionJob], - ]: + def create_batch_prediction_job(self) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob]]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -771,21 +691,18 @@ def create_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_batch_prediction_job" not in self._stubs: - self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", + if 'create_batch_prediction_job' not in self._stubs: + self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["create_batch_prediction_job"] + return self._stubs['create_batch_prediction_job'] @property - def get_batch_prediction_job( - self, - ) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - Awaitable[batch_prediction_job.BatchPredictionJob], - ]: + def get_batch_prediction_job(self) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob]]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -800,21 +717,18 @@ def get_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_batch_prediction_job" not in self._stubs: - self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", + if 'get_batch_prediction_job' not in self._stubs: + self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["get_batch_prediction_job"] + return self._stubs['get_batch_prediction_job'] @property - def list_batch_prediction_jobs( - self, - ) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - Awaitable[job_service.ListBatchPredictionJobsResponse], - ]: + def list_batch_prediction_jobs(self) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse]]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -829,20 +743,18 @@ def list_batch_prediction_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_batch_prediction_jobs" not in self._stubs: - self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", + if 'list_batch_prediction_jobs' not in self._stubs: + self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs["list_batch_prediction_jobs"] + return self._stubs['list_batch_prediction_jobs'] @property - def delete_batch_prediction_job( - self, - ) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] - ]: + def delete_batch_prediction_job(self) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -858,20 +770,18 @@ def delete_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_batch_prediction_job" not in self._stubs: - self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", + if 'delete_batch_prediction_job' not in self._stubs: + self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_batch_prediction_job"] + return self._stubs['delete_batch_prediction_job'] @property - def cancel_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] - ]: + def cancel_batch_prediction_job(self) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -897,13 +807,238 @@ def cancel_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_batch_prediction_job" not in self._stubs: - self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", + if 'cancel_batch_prediction_job' not in self._stubs: + self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_batch_prediction_job"] + return self._stubs['cancel_batch_prediction_job'] + + @property + def create_model_deployment_monitoring_job(self) -> Callable[ + [job_service.CreateModelDeploymentMonitoringJobRequest], + Awaitable[gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob]]: + r"""Return a callable for the create model deployment + monitoring job method over gRPC. + + Creates a ModelDeploymentMonitoringJob. It will run + periodically on a configured interval. + + Returns: + Callable[[~.CreateModelDeploymentMonitoringJobRequest], + Awaitable[~.ModelDeploymentMonitoringJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_model_deployment_monitoring_job' not in self._stubs: + self._stubs['create_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateModelDeploymentMonitoringJob', + request_serializer=job_service.CreateModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, + ) + return self._stubs['create_model_deployment_monitoring_job'] + @property + def search_model_deployment_monitoring_stats_anomalies(self) -> Callable[ + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]]: + r"""Return a callable for the search model deployment + monitoring stats anomalies method over gRPC. + + Searches Model Monitoring Statistics generated within + a given time window. + + Returns: + Callable[[~.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + Awaitable[~.SearchModelDeploymentMonitoringStatsAnomaliesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'search_model_deployment_monitoring_stats_anomalies' not in self._stubs: + self._stubs['search_model_deployment_monitoring_stats_anomalies'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/SearchModelDeploymentMonitoringStatsAnomalies', + request_serializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest.serialize, + response_deserializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse.deserialize, + ) + return self._stubs['search_model_deployment_monitoring_stats_anomalies'] + + @property + def get_model_deployment_monitoring_job(self) -> Callable[ + [job_service.GetModelDeploymentMonitoringJobRequest], + Awaitable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]]: + r"""Return a callable for the get model deployment + monitoring job method over gRPC. + + Gets a ModelDeploymentMonitoringJob. + + Returns: + Callable[[~.GetModelDeploymentMonitoringJobRequest], + Awaitable[~.ModelDeploymentMonitoringJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_model_deployment_monitoring_job' not in self._stubs: + self._stubs['get_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetModelDeploymentMonitoringJob', + request_serializer=job_service.GetModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, + ) + return self._stubs['get_model_deployment_monitoring_job'] + + @property + def list_model_deployment_monitoring_jobs(self) -> Callable[ + [job_service.ListModelDeploymentMonitoringJobsRequest], + Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse]]: + r"""Return a callable for the list model deployment + monitoring jobs method over gRPC. -__all__ = ("JobServiceGrpcAsyncIOTransport",) + Lists ModelDeploymentMonitoringJobs in a Location. + + Returns: + Callable[[~.ListModelDeploymentMonitoringJobsRequest], + Awaitable[~.ListModelDeploymentMonitoringJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_model_deployment_monitoring_jobs' not in self._stubs: + self._stubs['list_model_deployment_monitoring_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListModelDeploymentMonitoringJobs', + request_serializer=job_service.ListModelDeploymentMonitoringJobsRequest.serialize, + response_deserializer=job_service.ListModelDeploymentMonitoringJobsResponse.deserialize, + ) + return self._stubs['list_model_deployment_monitoring_jobs'] + + @property + def update_model_deployment_monitoring_job(self) -> Callable[ + [job_service.UpdateModelDeploymentMonitoringJobRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the update model deployment + monitoring job method over gRPC. + + Updates a ModelDeploymentMonitoringJob. + + Returns: + Callable[[~.UpdateModelDeploymentMonitoringJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_model_deployment_monitoring_job' not in self._stubs: + self._stubs['update_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/UpdateModelDeploymentMonitoringJob', + request_serializer=job_service.UpdateModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_model_deployment_monitoring_job'] + + @property + def delete_model_deployment_monitoring_job(self) -> Callable[ + [job_service.DeleteModelDeploymentMonitoringJobRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete model deployment + monitoring job method over gRPC. + + Deletes a ModelDeploymentMonitoringJob. + + Returns: + Callable[[~.DeleteModelDeploymentMonitoringJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_model_deployment_monitoring_job' not in self._stubs: + self._stubs['delete_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteModelDeploymentMonitoringJob', + request_serializer=job_service.DeleteModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_model_deployment_monitoring_job'] + + @property + def pause_model_deployment_monitoring_job(self) -> Callable[ + [job_service.PauseModelDeploymentMonitoringJobRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the pause model deployment + monitoring job method over gRPC. + + Pauses a ModelDeploymentMonitoringJob. If the job is running, + the server makes a best effort to cancel the job. Will mark + ``ModelDeploymentMonitoringJob.state`` + to 'PAUSED'. + + Returns: + Callable[[~.PauseModelDeploymentMonitoringJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'pause_model_deployment_monitoring_job' not in self._stubs: + self._stubs['pause_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/PauseModelDeploymentMonitoringJob', + request_serializer=job_service.PauseModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['pause_model_deployment_monitoring_job'] + + @property + def resume_model_deployment_monitoring_job(self) -> Callable[ + [job_service.ResumeModelDeploymentMonitoringJobRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the resume model deployment + monitoring job method over gRPC. + + Resumes a paused ModelDeploymentMonitoringJob. It + will start to run from next scheduled time. A deleted + ModelDeploymentMonitoringJob can't be resumed. + + Returns: + Callable[[~.ResumeModelDeploymentMonitoringJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'resume_model_deployment_monitoring_job' not in self._stubs: + self._stubs['resume_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ResumeModelDeploymentMonitoringJob', + request_serializer=job_service.ResumeModelDeploymentMonitoringJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['resume_model_deployment_monitoring_job'] + + +__all__ = ( + 'JobServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py new file mode 100644 index 0000000000..1f8cc4b7fb --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import MetadataServiceClient +from .async_client import MetadataServiceAsyncClient + +__all__ = ( + 'MetadataServiceClient', + 'MetadataServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py new file mode 100644 index 0000000000..d47a250882 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -0,0 +1,2487 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import encryption_spec +from google.cloud.aiplatform_v1beta1.types import event +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import lineage_subgraph +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store +from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import MetadataServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import MetadataServiceGrpcAsyncIOTransport +from .client import MetadataServiceClient + + +class MetadataServiceAsyncClient: + """Service for reading and writing metadata entries.""" + + _client: MetadataServiceClient + + DEFAULT_ENDPOINT = MetadataServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = MetadataServiceClient.DEFAULT_MTLS_ENDPOINT + + artifact_path = staticmethod(MetadataServiceClient.artifact_path) + parse_artifact_path = staticmethod(MetadataServiceClient.parse_artifact_path) + context_path = staticmethod(MetadataServiceClient.context_path) + parse_context_path = staticmethod(MetadataServiceClient.parse_context_path) + execution_path = staticmethod(MetadataServiceClient.execution_path) + parse_execution_path = staticmethod(MetadataServiceClient.parse_execution_path) + metadata_schema_path = staticmethod(MetadataServiceClient.metadata_schema_path) + parse_metadata_schema_path = staticmethod(MetadataServiceClient.parse_metadata_schema_path) + metadata_store_path = staticmethod(MetadataServiceClient.metadata_store_path) + parse_metadata_store_path = staticmethod(MetadataServiceClient.parse_metadata_store_path) + + common_billing_account_path = staticmethod(MetadataServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(MetadataServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(MetadataServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(MetadataServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(MetadataServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(MetadataServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(MetadataServiceClient.common_project_path) + parse_common_project_path = staticmethod(MetadataServiceClient.parse_common_project_path) + + common_location_path = staticmethod(MetadataServiceClient.common_location_path) + parse_common_location_path = staticmethod(MetadataServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MetadataServiceAsyncClient: The constructed client. + """ + return MetadataServiceClient.from_service_account_info.__func__(MetadataServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MetadataServiceAsyncClient: The constructed client. + """ + return MetadataServiceClient.from_service_account_file.__func__(MetadataServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> MetadataServiceTransport: + """Return the transport used by the client instance. + + Returns: + MetadataServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(MetadataServiceClient).get_transport_class, type(MetadataServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, MetadataServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the metadata service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.MetadataServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = MetadataServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_metadata_store(self, + request: metadata_service.CreateMetadataStoreRequest = None, + *, + parent: str = None, + metadata_store: gca_metadata_store.MetadataStore = None, + metadata_store_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Initializes a MetadataStore, including allocation of + resources. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateMetadataStoreRequest`): + The request object. Request message for + ``MetadataService.CreateMetadataStore``. + parent (:class:`str`): + Required. The resource name of the + Location where the MetadataStore should + be created. Format: + projects/{project}/locations/{location}/ + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_store (:class:`google.cloud.aiplatform_v1beta1.types.MetadataStore`): + Required. The MetadataStore to + create. + + This corresponds to the ``metadata_store`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_store_id (:class:`str`): + The {metadatastore} portion of the resource name with + the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + If not provided, the MetadataStore's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all MetadataStores in the parent Location. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the + preexisting MetadataStore.) + + This corresponds to the ``metadata_store_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.MetadataStore` Instance of a metadata store. Contains a set of metadata that can be + queried. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, metadata_store, metadata_store_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.CreateMetadataStoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if metadata_store is not None: + request.metadata_store = metadata_store + if metadata_store_id is not None: + request.metadata_store_id = metadata_store_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_metadata_store, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_metadata_store.MetadataStore, + metadata_type=metadata_service.CreateMetadataStoreOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_metadata_store(self, + request: metadata_service.GetMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_store.MetadataStore: + r"""Retrieves a specific MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetMetadataStoreRequest`): + The request object. Request message for + ``MetadataService.GetMetadataStore``. + name (:class:`str`): + Required. The resource name of the + MetadataStore to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.MetadataStore: + Instance of a metadata store. + Contains a set of metadata that can be + queried. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.GetMetadataStoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_metadata_store, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_metadata_stores(self, + request: metadata_service.ListMetadataStoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataStoresAsyncPager: + r"""Lists MetadataStores for a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListMetadataStoresRequest`): + The request object. Request message for + ``MetadataService.ListMetadataStores``. + parent (:class:`str`): + Required. The Location whose + MetadataStores should be listed. Format: + projects/{project}/locations/{location} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListMetadataStoresAsyncPager: + Response message for + ``MetadataService.ListMetadataStores``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.ListMetadataStoresRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_metadata_stores, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListMetadataStoresAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_metadata_store(self, + request: metadata_service.DeleteMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a single MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteMetadataStoreRequest`): + The request object. Request message for + ``MetadataService.DeleteMetadataStore``. + name (:class:`str`): + Required. The resource name of the + MetadataStore to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.DeleteMetadataStoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_metadata_store, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=metadata_service.DeleteMetadataStoreOperationMetadata, + ) + + # Done; return the response. + return response + + async def create_artifact(self, + request: metadata_service.CreateArtifactRequest = None, + *, + parent: str = None, + artifact: gca_artifact.Artifact = None, + artifact_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: + r"""Creates an Artifact associated with a MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateArtifactRequest`): + The request object. Request message for + ``MetadataService.CreateArtifact``. + parent (:class:`str`): + Required. The resource name of the + MetadataStore where the Artifact should + be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + artifact (:class:`google.cloud.aiplatform_v1beta1.types.Artifact`): + Required. The Artifact to create. + This corresponds to the ``artifact`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + artifact_id (:class:`str`): + The {artifact} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + If not provided, the Artifact's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all Artifacts in the parent MetadataStore. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the + preexisting Artifact.) + + This corresponds to the ``artifact_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Artifact: + Instance of a general artifact. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, artifact, artifact_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.CreateArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if artifact is not None: + request.artifact = artifact + if artifact_id is not None: + request.artifact_id = artifact_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_artifact, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_artifact(self, + request: metadata_service.GetArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> artifact.Artifact: + r"""Retrieves a specific Artifact. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetArtifactRequest`): + The request object. Request message for + ``MetadataService.GetArtifact``. + name (:class:`str`): + Required. The resource name of the + Artifact to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Artifact: + Instance of a general artifact. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.GetArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_artifact, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_artifacts(self, + request: metadata_service.ListArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListArtifactsAsyncPager: + r"""Lists Artifacts in the MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListArtifactsRequest`): + The request object. Request message for + ``MetadataService.ListArtifacts``. + parent (:class:`str`): + Required. The MetadataStore whose + Artifacts should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListArtifactsAsyncPager: + Response message for + ``MetadataService.ListArtifacts``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.ListArtifactsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_artifacts, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListArtifactsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_artifact(self, + request: metadata_service.UpdateArtifactRequest = None, + *, + artifact: gca_artifact.Artifact = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: + r"""Updates a stored Artifact. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateArtifactRequest`): + The request object. Request message for + ``MetadataService.UpdateArtifact``. + artifact (:class:`google.cloud.aiplatform_v1beta1.types.Artifact`): + Required. The Artifact containing updates. The + Artifact's + ``Artifact.name`` + field is used to identify the Artifact to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + This corresponds to the ``artifact`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. A FieldMask indicating + which fields should be updated. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Artifact: + Instance of a general artifact. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([artifact, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.UpdateArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if artifact is not None: + request.artifact = artifact + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_artifact, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('artifact.name', request.artifact.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_context(self, + request: metadata_service.CreateContextRequest = None, + *, + parent: str = None, + context: gca_context.Context = None, + context_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: + r"""Creates a Context associated with a MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateContextRequest`): + The request object. Request message for + ``MetadataService.CreateContext``. + parent (:class:`str`): + Required. The resource name of the + MetadataStore where the Context should + be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + context (:class:`google.cloud.aiplatform_v1beta1.types.Context`): + Required. The Context to create. + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + context_id (:class:`str`): + The {context} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + If not provided, the Context's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all Contexts in the parent MetadataStore. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the + preexisting Context.) + + This corresponds to the ``context_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Context: + Instance of a general context. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, context, context_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.CreateContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if context is not None: + request.context = context + if context_id is not None: + request.context_id = context_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_context, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_context(self, + request: metadata_service.GetContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> context.Context: + r"""Retrieves a specific Context. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetContextRequest`): + The request object. Request message for + ``MetadataService.GetContext``. + name (:class:`str`): + Required. The resource name of the + Context to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Context: + Instance of a general context. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.GetContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_context, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_contexts(self, + request: metadata_service.ListContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListContextsAsyncPager: + r"""Lists Contexts on the MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListContextsRequest`): + The request object. Request message for + ``MetadataService.ListContexts`` + parent (:class:`str`): + Required. The MetadataStore whose + Contexts should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListContextsAsyncPager: + Response message for + ``MetadataService.ListContexts``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.ListContextsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_contexts, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListContextsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_context(self, + request: metadata_service.UpdateContextRequest = None, + *, + context: gca_context.Context = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: + r"""Updates a stored Context. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateContextRequest`): + The request object. Request message for + ``MetadataService.UpdateContext``. + context (:class:`google.cloud.aiplatform_v1beta1.types.Context`): + Required. The Context containing updates. The Context's + ``Context.name`` + field is used to identify the Context to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. A FieldMask indicating + which fields should be updated. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Context: + Instance of a general context. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.UpdateContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_context, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context.name', request.context.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_context(self, + request: metadata_service.DeleteContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a stored Context. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteContextRequest`): + The request object. Request message for + ``MetadataService.DeleteContext``. + name (:class:`str`): + Required. The resource name of the + Context to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.DeleteContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_context, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def add_context_artifacts_and_executions(self, + request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, + *, + context: str = None, + artifacts: Sequence[str] = None, + executions: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: + r"""Adds a set of Artifacts and Executions to a Context. + If any of the Artifacts or Executions have already been + added to a Context, they are simply skipped. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.AddContextArtifactsAndExecutionsRequest`): + The request object. Request message for + ``MetadataService.AddContextArtifactsAndExecutions``. + context (:class:`str`): + Required. The resource name of the + Context that the Artifacts and + Executions belong to. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + artifacts (:class:`Sequence[str]`): + The resource names of the Artifacts + to attribute to the Context. + + This corresponds to the ``artifacts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + executions (:class:`Sequence[str]`): + The resource names of the Executions + to associate with the Context. + + This corresponds to the ``executions`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.AddContextArtifactsAndExecutionsResponse: + Response message for + ``MetadataService.AddContextArtifactsAndExecutions``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context, artifacts, executions]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + + if artifacts: + request.artifacts.extend(artifacts) + if executions: + request.executions.extend(executions) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.add_context_artifacts_and_executions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context', request.context), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def add_context_children(self, + request: metadata_service.AddContextChildrenRequest = None, + *, + context: str = None, + child_contexts: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextChildrenResponse: + r"""Adds a set of Contexts as children to a parent Context. If any + of the child Contexts have already been added to the parent + Context, they are simply skipped. If this call would create a + cycle or cause any Context to have more than 10 parents, the + request will fail with INVALID_ARGUMENT error. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.AddContextChildrenRequest`): + The request object. Request message for + ``MetadataService.AddContextChildren``. + context (:class:`str`): + Required. The resource name of the + parent Context. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + child_contexts (:class:`Sequence[str]`): + The resource names of the child + Contexts. + + This corresponds to the ``child_contexts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.AddContextChildrenResponse: + Response message for + ``MetadataService.AddContextChildren``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context, child_contexts]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.AddContextChildrenRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + + if child_contexts: + request.child_contexts.extend(child_contexts) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.add_context_children, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context', request.context), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def query_context_lineage_subgraph(self, + request: metadata_service.QueryContextLineageSubgraphRequest = None, + *, + context: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: + r"""Retrieves Artifacts and Executions within the + specified Context, connected by Event edges and returned + as a LineageSubgraph. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.QueryContextLineageSubgraphRequest`): + The request object. Request message for + ``MetadataService.QueryContextLineageSubgraph``. + context (:class:`str`): + Required. The resource name of the Context whose + Artifacts and Executions should be retrieved as a + LineageSubgraph. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + The request may error with FAILED_PRECONDITION if the + number of Artifacts, the number of Executions, or the + number of Events that would be returned for the Context + exceeds 1000. + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.LineageSubgraph: + A subgraph of the overall lineage + graph. Event edges connect Artifact and + Execution nodes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.QueryContextLineageSubgraphRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.query_context_lineage_subgraph, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context', request.context), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_execution(self, + request: metadata_service.CreateExecutionRequest = None, + *, + parent: str = None, + execution: gca_execution.Execution = None, + execution_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: + r"""Creates an Execution associated with a MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateExecutionRequest`): + The request object. Request message for + ``MetadataService.CreateExecution``. + parent (:class:`str`): + Required. The resource name of the + MetadataStore where the Execution should + be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + execution (:class:`google.cloud.aiplatform_v1beta1.types.Execution`): + Required. The Execution to create. + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + execution_id (:class:`str`): + The {execution} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + If not provided, the Execution's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all Executions in the parent + MetadataStore. (Otherwise the request will fail with + ALREADY_EXISTS, or PERMISSION_DENIED if the caller can't + view the preexisting Execution.) + + This corresponds to the ``execution_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Execution: + Instance of a general execution. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, execution, execution_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.CreateExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if execution is not None: + request.execution = execution + if execution_id is not None: + request.execution_id = execution_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_execution, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_execution(self, + request: metadata_service.GetExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> execution.Execution: + r"""Retrieves a specific Execution. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetExecutionRequest`): + The request object. Request message for + ``MetadataService.GetExecution``. + name (:class:`str`): + Required. The resource name of the + Execution to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Execution: + Instance of a general execution. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.GetExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_execution, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_executions(self, + request: metadata_service.ListExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListExecutionsAsyncPager: + r"""Lists Executions in the MetadataStore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListExecutionsRequest`): + The request object. Request message for + ``MetadataService.ListExecutions``. + parent (:class:`str`): + Required. The MetadataStore whose + Executions should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListExecutionsAsyncPager: + Response message for + ``MetadataService.ListExecutions``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.ListExecutionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_executions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListExecutionsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_execution(self, + request: metadata_service.UpdateExecutionRequest = None, + *, + execution: gca_execution.Execution = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: + r"""Updates a stored Execution. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateExecutionRequest`): + The request object. Request message for + ``MetadataService.UpdateExecution``. + execution (:class:`google.cloud.aiplatform_v1beta1.types.Execution`): + Required. The Execution containing updates. The + Execution's + ``Execution.name`` + field is used to identify the Execution to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. A FieldMask indicating + which fields should be updated. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Execution: + Instance of a general execution. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([execution, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.UpdateExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if execution is not None: + request.execution = execution + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_execution, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('execution.name', request.execution.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def add_execution_events(self, + request: metadata_service.AddExecutionEventsRequest = None, + *, + execution: str = None, + events: Sequence[event.Event] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddExecutionEventsResponse: + r"""Adds Events for denoting whether each Artifact was an + input or output for a given Execution. If any Events + already exist between the Execution and any of the + specified Artifacts they are simply skipped. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.AddExecutionEventsRequest`): + The request object. Request message for + ``MetadataService.AddExecutionEvents``. + execution (:class:`str`): + Required. The resource name of the + Execution that the Events connect + Artifacts with. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + events (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.Event]`): + The Events to create and add. + This corresponds to the ``events`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.AddExecutionEventsResponse: + Response message for + ``MetadataService.AddExecutionEvents``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([execution, events]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.AddExecutionEventsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if execution is not None: + request.execution = execution + + if events: + request.events.extend(events) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.add_execution_events, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('execution', request.execution), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def query_execution_inputs_and_outputs(self, + request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, + *, + execution: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: + r"""Obtains the set of input and output Artifacts for + this Execution, in the form of LineageSubgraph that also + contains the Execution and connecting Events. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.QueryExecutionInputsAndOutputsRequest`): + The request object. Request message for + ``MetadataService.QueryExecutionInputsAndOutputs``. + execution (:class:`str`): + Required. The resource name of the + Execution whose input and output + Artifacts should be retrieved as a + LineageSubgraph. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.LineageSubgraph: + A subgraph of the overall lineage + graph. Event edges connect Artifact and + Execution nodes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([execution]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if execution is not None: + request.execution = execution + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.query_execution_inputs_and_outputs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('execution', request.execution), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_metadata_schema(self, + request: metadata_service.CreateMetadataSchemaRequest = None, + *, + parent: str = None, + metadata_schema: gca_metadata_schema.MetadataSchema = None, + metadata_schema_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_metadata_schema.MetadataSchema: + r"""Creates an MetadataSchema. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateMetadataSchemaRequest`): + The request object. Request message for + ``MetadataService.CreateMetadataSchema``. + parent (:class:`str`): + Required. The resource name of the + MetadataStore where the MetadataSchema + should be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_schema (:class:`google.cloud.aiplatform_v1beta1.types.MetadataSchema`): + Required. The MetadataSchema to + create. + + This corresponds to the ``metadata_schema`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_schema_id (:class:`str`): + The {metadata_schema} portion of the resource name with + the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/metadataSchemas/{metadataschema} + If not provided, the MetadataStore's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all MetadataSchemas in the parent + Location. (Otherwise the request will fail with + ALREADY_EXISTS, or PERMISSION_DENIED if the caller can't + view the preexisting MetadataSchema.) + + This corresponds to the ``metadata_schema_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.MetadataSchema: + Instance of a general MetadataSchema. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.CreateMetadataSchemaRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if metadata_schema is not None: + request.metadata_schema = metadata_schema + if metadata_schema_id is not None: + request.metadata_schema_id = metadata_schema_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_metadata_schema, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_metadata_schema(self, + request: metadata_service.GetMetadataSchemaRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_schema.MetadataSchema: + r"""Retrieves a specific MetadataSchema. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetMetadataSchemaRequest`): + The request object. Request message for + ``MetadataService.GetMetadataSchema``. + name (:class:`str`): + Required. The resource name of the + MetadataSchema to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/metadataSchemas/{metadataschema} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.MetadataSchema: + Instance of a general MetadataSchema. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.GetMetadataSchemaRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_metadata_schema, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_metadata_schemas(self, + request: metadata_service.ListMetadataSchemasRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataSchemasAsyncPager: + r"""Lists MetadataSchemas. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasRequest`): + The request object. Request message for + ``MetadataService.ListMetadataSchemas``. + parent (:class:`str`): + Required. The MetadataStore whose + MetadataSchemas should be listed. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListMetadataSchemasAsyncPager: + Response message for + ``MetadataService.ListMetadataSchemas``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.ListMetadataSchemasRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_metadata_schemas, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListMetadataSchemasAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'MetadataServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py new file mode 100644 index 0000000000..e1fcc67567 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -0,0 +1,2717 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import encryption_spec +from google.cloud.aiplatform_v1beta1.types import event +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import lineage_subgraph +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store +from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import MetadataServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import MetadataServiceGrpcTransport +from .transports.grpc_asyncio import MetadataServiceGrpcAsyncIOTransport + + +class MetadataServiceClientMeta(type): + """Metaclass for the MetadataService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[MetadataServiceTransport]] + _transport_registry['grpc'] = MetadataServiceGrpcTransport + _transport_registry['grpc_asyncio'] = MetadataServiceGrpcAsyncIOTransport + + def get_transport_class(cls, + label: str = None, + ) -> Type[MetadataServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class MetadataServiceClient(metaclass=MetadataServiceClientMeta): + """Service for reading and writing metadata entries.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MetadataServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MetadataServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> MetadataServiceTransport: + """Return the transport used by the client instance. + + Returns: + MetadataServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def artifact_path(project: str,location: str,metadata_store: str,artifact: str,) -> str: + """Return a fully-qualified artifact string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(project=project, location=location, metadata_store=metadata_store, artifact=artifact, ) + + @staticmethod + def parse_artifact_path(path: str) -> Dict[str,str]: + """Parse a artifact path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/artifacts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def context_path(project: str,location: str,metadata_store: str,context: str,) -> str: + """Return a fully-qualified context string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(project=project, location=location, metadata_store=metadata_store, context=context, ) + + @staticmethod + def parse_context_path(path: str) -> Dict[str,str]: + """Parse a context path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def execution_path(project: str,location: str,metadata_store: str,execution: str,) -> str: + """Return a fully-qualified execution string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(project=project, location=location, metadata_store=metadata_store, execution=execution, ) + + @staticmethod + def parse_execution_path(path: str) -> Dict[str,str]: + """Parse a execution path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/executions/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def metadata_schema_path(project: str,location: str,metadata_store: str,metadata_schema: str,) -> str: + """Return a fully-qualified metadata_schema string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format(project=project, location=location, metadata_store=metadata_store, metadata_schema=metadata_schema, ) + + @staticmethod + def parse_metadata_schema_path(path: str) -> Dict[str,str]: + """Parse a metadata_schema path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/metadataSchemas/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def metadata_store_path(project: str,location: str,metadata_store: str,) -> str: + """Return a fully-qualified metadata_store string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format(project=project, location=location, metadata_store=metadata_store, ) + + @staticmethod + def parse_metadata_store_path(path: str) -> Dict[str,str]: + """Parse a metadata_store path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MetadataServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the metadata service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, MetadataServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, MetadataServiceTransport): + # transport is a MetadataServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_metadata_store(self, + request: metadata_service.CreateMetadataStoreRequest = None, + *, + parent: str = None, + metadata_store: gca_metadata_store.MetadataStore = None, + metadata_store_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Initializes a MetadataStore, including allocation of + resources. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateMetadataStoreRequest): + The request object. Request message for + ``MetadataService.CreateMetadataStore``. + parent (str): + Required. The resource name of the + Location where the MetadataStore should + be created. Format: + projects/{project}/locations/{location}/ + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_store (google.cloud.aiplatform_v1beta1.types.MetadataStore): + Required. The MetadataStore to + create. + + This corresponds to the ``metadata_store`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_store_id (str): + The {metadatastore} portion of the resource name with + the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + If not provided, the MetadataStore's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all MetadataStores in the parent Location. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the + preexisting MetadataStore.) + + This corresponds to the ``metadata_store_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.MetadataStore` Instance of a metadata store. Contains a set of metadata that can be + queried. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, metadata_store, metadata_store_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.CreateMetadataStoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.CreateMetadataStoreRequest): + request = metadata_service.CreateMetadataStoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if metadata_store is not None: + request.metadata_store = metadata_store + if metadata_store_id is not None: + request.metadata_store_id = metadata_store_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_metadata_store] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + gca_metadata_store.MetadataStore, + metadata_type=metadata_service.CreateMetadataStoreOperationMetadata, + ) + + # Done; return the response. + return response + + def get_metadata_store(self, + request: metadata_service.GetMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_store.MetadataStore: + r"""Retrieves a specific MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetMetadataStoreRequest): + The request object. Request message for + ``MetadataService.GetMetadataStore``. + name (str): + Required. The resource name of the + MetadataStore to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.MetadataStore: + Instance of a metadata store. + Contains a set of metadata that can be + queried. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.GetMetadataStoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.GetMetadataStoreRequest): + request = metadata_service.GetMetadataStoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_metadata_store] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_metadata_stores(self, + request: metadata_service.ListMetadataStoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataStoresPager: + r"""Lists MetadataStores for a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresRequest): + The request object. Request message for + ``MetadataService.ListMetadataStores``. + parent (str): + Required. The Location whose + MetadataStores should be listed. Format: + projects/{project}/locations/{location} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListMetadataStoresPager: + Response message for + ``MetadataService.ListMetadataStores``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.ListMetadataStoresRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.ListMetadataStoresRequest): + request = metadata_service.ListMetadataStoresRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_metadata_stores] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListMetadataStoresPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_metadata_store(self, + request: metadata_service.DeleteMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a single MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteMetadataStoreRequest): + The request object. Request message for + ``MetadataService.DeleteMetadataStore``. + name (str): + Required. The resource name of the + MetadataStore to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.DeleteMetadataStoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.DeleteMetadataStoreRequest): + request = metadata_service.DeleteMetadataStoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_metadata_store] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=metadata_service.DeleteMetadataStoreOperationMetadata, + ) + + # Done; return the response. + return response + + def create_artifact(self, + request: metadata_service.CreateArtifactRequest = None, + *, + parent: str = None, + artifact: gca_artifact.Artifact = None, + artifact_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: + r"""Creates an Artifact associated with a MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateArtifactRequest): + The request object. Request message for + ``MetadataService.CreateArtifact``. + parent (str): + Required. The resource name of the + MetadataStore where the Artifact should + be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + artifact (google.cloud.aiplatform_v1beta1.types.Artifact): + Required. The Artifact to create. + This corresponds to the ``artifact`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + artifact_id (str): + The {artifact} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + If not provided, the Artifact's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all Artifacts in the parent MetadataStore. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the + preexisting Artifact.) + + This corresponds to the ``artifact_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Artifact: + Instance of a general artifact. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, artifact, artifact_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.CreateArtifactRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.CreateArtifactRequest): + request = metadata_service.CreateArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if artifact is not None: + request.artifact = artifact + if artifact_id is not None: + request.artifact_id = artifact_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_artifact] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_artifact(self, + request: metadata_service.GetArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> artifact.Artifact: + r"""Retrieves a specific Artifact. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetArtifactRequest): + The request object. Request message for + ``MetadataService.GetArtifact``. + name (str): + Required. The resource name of the + Artifact to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Artifact: + Instance of a general artifact. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.GetArtifactRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.GetArtifactRequest): + request = metadata_service.GetArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_artifact] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_artifacts(self, + request: metadata_service.ListArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListArtifactsPager: + r"""Lists Artifacts in the MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListArtifactsRequest): + The request object. Request message for + ``MetadataService.ListArtifacts``. + parent (str): + Required. The MetadataStore whose + Artifacts should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListArtifactsPager: + Response message for + ``MetadataService.ListArtifacts``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.ListArtifactsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.ListArtifactsRequest): + request = metadata_service.ListArtifactsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_artifacts] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListArtifactsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_artifact(self, + request: metadata_service.UpdateArtifactRequest = None, + *, + artifact: gca_artifact.Artifact = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: + r"""Updates a stored Artifact. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateArtifactRequest): + The request object. Request message for + ``MetadataService.UpdateArtifact``. + artifact (google.cloud.aiplatform_v1beta1.types.Artifact): + Required. The Artifact containing updates. The + Artifact's + ``Artifact.name`` + field is used to identify the Artifact to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + This corresponds to the ``artifact`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. A FieldMask indicating + which fields should be updated. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Artifact: + Instance of a general artifact. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([artifact, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.UpdateArtifactRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.UpdateArtifactRequest): + request = metadata_service.UpdateArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if artifact is not None: + request.artifact = artifact + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_artifact] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('artifact.name', request.artifact.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_context(self, + request: metadata_service.CreateContextRequest = None, + *, + parent: str = None, + context: gca_context.Context = None, + context_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: + r"""Creates a Context associated with a MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateContextRequest): + The request object. Request message for + ``MetadataService.CreateContext``. + parent (str): + Required. The resource name of the + MetadataStore where the Context should + be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + context (google.cloud.aiplatform_v1beta1.types.Context): + Required. The Context to create. + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + context_id (str): + The {context} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + If not provided, the Context's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all Contexts in the parent MetadataStore. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the + preexisting Context.) + + This corresponds to the ``context_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Context: + Instance of a general context. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, context, context_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.CreateContextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.CreateContextRequest): + request = metadata_service.CreateContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if context is not None: + request.context = context + if context_id is not None: + request.context_id = context_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_context] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_context(self, + request: metadata_service.GetContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> context.Context: + r"""Retrieves a specific Context. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetContextRequest): + The request object. Request message for + ``MetadataService.GetContext``. + name (str): + Required. The resource name of the + Context to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Context: + Instance of a general context. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.GetContextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.GetContextRequest): + request = metadata_service.GetContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_context] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_contexts(self, + request: metadata_service.ListContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListContextsPager: + r"""Lists Contexts on the MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListContextsRequest): + The request object. Request message for + ``MetadataService.ListContexts`` + parent (str): + Required. The MetadataStore whose + Contexts should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListContextsPager: + Response message for + ``MetadataService.ListContexts``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.ListContextsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.ListContextsRequest): + request = metadata_service.ListContextsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_contexts] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListContextsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_context(self, + request: metadata_service.UpdateContextRequest = None, + *, + context: gca_context.Context = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: + r"""Updates a stored Context. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateContextRequest): + The request object. Request message for + ``MetadataService.UpdateContext``. + context (google.cloud.aiplatform_v1beta1.types.Context): + Required. The Context containing updates. The Context's + ``Context.name`` + field is used to identify the Context to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. A FieldMask indicating + which fields should be updated. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Context: + Instance of a general context. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.UpdateContextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.UpdateContextRequest): + request = metadata_service.UpdateContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_context] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context.name', request.context.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_context(self, + request: metadata_service.DeleteContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a stored Context. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteContextRequest): + The request object. Request message for + ``MetadataService.DeleteContext``. + name (str): + Required. The resource name of the + Context to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.DeleteContextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.DeleteContextRequest): + request = metadata_service.DeleteContextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_context] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def add_context_artifacts_and_executions(self, + request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, + *, + context: str = None, + artifacts: Sequence[str] = None, + executions: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: + r"""Adds a set of Artifacts and Executions to a Context. + If any of the Artifacts or Executions have already been + added to a Context, they are simply skipped. + + Args: + request (google.cloud.aiplatform_v1beta1.types.AddContextArtifactsAndExecutionsRequest): + The request object. Request message for + ``MetadataService.AddContextArtifactsAndExecutions``. + context (str): + Required. The resource name of the + Context that the Artifacts and + Executions belong to. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + artifacts (Sequence[str]): + The resource names of the Artifacts + to attribute to the Context. + + This corresponds to the ``artifacts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + executions (Sequence[str]): + The resource names of the Executions + to associate with the Context. + + This corresponds to the ``executions`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.AddContextArtifactsAndExecutionsResponse: + Response message for + ``MetadataService.AddContextArtifactsAndExecutions``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context, artifacts, executions]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.AddContextArtifactsAndExecutionsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.AddContextArtifactsAndExecutionsRequest): + request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + if artifacts is not None: + request.artifacts = artifacts + if executions is not None: + request.executions = executions + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.add_context_artifacts_and_executions] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context', request.context), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def add_context_children(self, + request: metadata_service.AddContextChildrenRequest = None, + *, + context: str = None, + child_contexts: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextChildrenResponse: + r"""Adds a set of Contexts as children to a parent Context. If any + of the child Contexts have already been added to the parent + Context, they are simply skipped. If this call would create a + cycle or cause any Context to have more than 10 parents, the + request will fail with INVALID_ARGUMENT error. + + Args: + request (google.cloud.aiplatform_v1beta1.types.AddContextChildrenRequest): + The request object. Request message for + ``MetadataService.AddContextChildren``. + context (str): + Required. The resource name of the + parent Context. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + child_contexts (Sequence[str]): + The resource names of the child + Contexts. + + This corresponds to the ``child_contexts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.AddContextChildrenResponse: + Response message for + ``MetadataService.AddContextChildren``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context, child_contexts]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.AddContextChildrenRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.AddContextChildrenRequest): + request = metadata_service.AddContextChildrenRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + if child_contexts is not None: + request.child_contexts = child_contexts + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.add_context_children] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context', request.context), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def query_context_lineage_subgraph(self, + request: metadata_service.QueryContextLineageSubgraphRequest = None, + *, + context: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: + r"""Retrieves Artifacts and Executions within the + specified Context, connected by Event edges and returned + as a LineageSubgraph. + + Args: + request (google.cloud.aiplatform_v1beta1.types.QueryContextLineageSubgraphRequest): + The request object. Request message for + ``MetadataService.QueryContextLineageSubgraph``. + context (str): + Required. The resource name of the Context whose + Artifacts and Executions should be retrieved as a + LineageSubgraph. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + The request may error with FAILED_PRECONDITION if the + number of Artifacts, the number of Executions, or the + number of Events that would be returned for the Context + exceeds 1000. + + This corresponds to the ``context`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.LineageSubgraph: + A subgraph of the overall lineage + graph. Event edges connect Artifact and + Execution nodes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([context]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.QueryContextLineageSubgraphRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.QueryContextLineageSubgraphRequest): + request = metadata_service.QueryContextLineageSubgraphRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if context is not None: + request.context = context + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.query_context_lineage_subgraph] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('context', request.context), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_execution(self, + request: metadata_service.CreateExecutionRequest = None, + *, + parent: str = None, + execution: gca_execution.Execution = None, + execution_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: + r"""Creates an Execution associated with a MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateExecutionRequest): + The request object. Request message for + ``MetadataService.CreateExecution``. + parent (str): + Required. The resource name of the + MetadataStore where the Execution should + be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + execution (google.cloud.aiplatform_v1beta1.types.Execution): + Required. The Execution to create. + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + execution_id (str): + The {execution} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + If not provided, the Execution's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all Executions in the parent + MetadataStore. (Otherwise the request will fail with + ALREADY_EXISTS, or PERMISSION_DENIED if the caller can't + view the preexisting Execution.) + + This corresponds to the ``execution_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Execution: + Instance of a general execution. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, execution, execution_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.CreateExecutionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.CreateExecutionRequest): + request = metadata_service.CreateExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if execution is not None: + request.execution = execution + if execution_id is not None: + request.execution_id = execution_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_execution] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_execution(self, + request: metadata_service.GetExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> execution.Execution: + r"""Retrieves a specific Execution. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetExecutionRequest): + The request object. Request message for + ``MetadataService.GetExecution``. + name (str): + Required. The resource name of the + Execution to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Execution: + Instance of a general execution. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.GetExecutionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.GetExecutionRequest): + request = metadata_service.GetExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_execution] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_executions(self, + request: metadata_service.ListExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListExecutionsPager: + r"""Lists Executions in the MetadataStore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListExecutionsRequest): + The request object. Request message for + ``MetadataService.ListExecutions``. + parent (str): + Required. The MetadataStore whose + Executions should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListExecutionsPager: + Response message for + ``MetadataService.ListExecutions``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.ListExecutionsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.ListExecutionsRequest): + request = metadata_service.ListExecutionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_executions] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListExecutionsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_execution(self, + request: metadata_service.UpdateExecutionRequest = None, + *, + execution: gca_execution.Execution = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: + r"""Updates a stored Execution. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateExecutionRequest): + The request object. Request message for + ``MetadataService.UpdateExecution``. + execution (google.cloud.aiplatform_v1beta1.types.Execution): + Required. The Execution containing updates. The + Execution's + ``Execution.name`` + field is used to identify the Execution to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. A FieldMask indicating + which fields should be updated. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Execution: + Instance of a general execution. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([execution, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.UpdateExecutionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.UpdateExecutionRequest): + request = metadata_service.UpdateExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if execution is not None: + request.execution = execution + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_execution] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('execution.name', request.execution.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def add_execution_events(self, + request: metadata_service.AddExecutionEventsRequest = None, + *, + execution: str = None, + events: Sequence[event.Event] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddExecutionEventsResponse: + r"""Adds Events for denoting whether each Artifact was an + input or output for a given Execution. If any Events + already exist between the Execution and any of the + specified Artifacts they are simply skipped. + + Args: + request (google.cloud.aiplatform_v1beta1.types.AddExecutionEventsRequest): + The request object. Request message for + ``MetadataService.AddExecutionEvents``. + execution (str): + Required. The resource name of the + Execution that the Events connect + Artifacts with. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + events (Sequence[google.cloud.aiplatform_v1beta1.types.Event]): + The Events to create and add. + This corresponds to the ``events`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.AddExecutionEventsResponse: + Response message for + ``MetadataService.AddExecutionEvents``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([execution, events]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.AddExecutionEventsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.AddExecutionEventsRequest): + request = metadata_service.AddExecutionEventsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if execution is not None: + request.execution = execution + if events is not None: + request.events = events + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.add_execution_events] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('execution', request.execution), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def query_execution_inputs_and_outputs(self, + request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, + *, + execution: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: + r"""Obtains the set of input and output Artifacts for + this Execution, in the form of LineageSubgraph that also + contains the Execution and connecting Events. + + Args: + request (google.cloud.aiplatform_v1beta1.types.QueryExecutionInputsAndOutputsRequest): + The request object. Request message for + ``MetadataService.QueryExecutionInputsAndOutputs``. + execution (str): + Required. The resource name of the + Execution whose input and output + Artifacts should be retrieved as a + LineageSubgraph. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``execution`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.LineageSubgraph: + A subgraph of the overall lineage + graph. Event edges connect Artifact and + Execution nodes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([execution]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.QueryExecutionInputsAndOutputsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.QueryExecutionInputsAndOutputsRequest): + request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if execution is not None: + request.execution = execution + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.query_execution_inputs_and_outputs] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('execution', request.execution), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_metadata_schema(self, + request: metadata_service.CreateMetadataSchemaRequest = None, + *, + parent: str = None, + metadata_schema: gca_metadata_schema.MetadataSchema = None, + metadata_schema_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_metadata_schema.MetadataSchema: + r"""Creates an MetadataSchema. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateMetadataSchemaRequest): + The request object. Request message for + ``MetadataService.CreateMetadataSchema``. + parent (str): + Required. The resource name of the + MetadataStore where the MetadataSchema + should be created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_schema (google.cloud.aiplatform_v1beta1.types.MetadataSchema): + Required. The MetadataSchema to + create. + + This corresponds to the ``metadata_schema`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + metadata_schema_id (str): + The {metadata_schema} portion of the resource name with + the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/metadataSchemas/{metadataschema} + If not provided, the MetadataStore's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be + unique across all MetadataSchemas in the parent + Location. (Otherwise the request will fail with + ALREADY_EXISTS, or PERMISSION_DENIED if the caller can't + view the preexisting MetadataSchema.) + + This corresponds to the ``metadata_schema_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.MetadataSchema: + Instance of a general MetadataSchema. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.CreateMetadataSchemaRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.CreateMetadataSchemaRequest): + request = metadata_service.CreateMetadataSchemaRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if metadata_schema is not None: + request.metadata_schema = metadata_schema + if metadata_schema_id is not None: + request.metadata_schema_id = metadata_schema_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_metadata_schema] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_metadata_schema(self, + request: metadata_service.GetMetadataSchemaRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_schema.MetadataSchema: + r"""Retrieves a specific MetadataSchema. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetMetadataSchemaRequest): + The request object. Request message for + ``MetadataService.GetMetadataSchema``. + name (str): + Required. The resource name of the + MetadataSchema to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/metadataSchemas/{metadataschema} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.MetadataSchema: + Instance of a general MetadataSchema. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.GetMetadataSchemaRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.GetMetadataSchemaRequest): + request = metadata_service.GetMetadataSchemaRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_metadata_schema] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_metadata_schemas(self, + request: metadata_service.ListMetadataSchemasRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataSchemasPager: + r"""Lists MetadataSchemas. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasRequest): + The request object. Request message for + ``MetadataService.ListMetadataSchemas``. + parent (str): + Required. The MetadataStore whose + MetadataSchemas should be listed. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.metadata_service.pagers.ListMetadataSchemasPager: + Response message for + ``MetadataService.ListMetadataSchemas``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.ListMetadataSchemasRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.ListMetadataSchemasRequest): + request = metadata_service.ListMetadataSchemasRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_metadata_schemas] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListMetadataSchemasPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'MetadataServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py new file mode 100644 index 0000000000..da04d2882f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py @@ -0,0 +1,635 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional + +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store + + +class ListMetadataStoresPager: + """A pager for iterating through ``list_metadata_stores`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse` object, and + provides an ``__iter__`` method to iterate through its + ``metadata_stores`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListMetadataStores`` requests and continue to iterate + through the ``metadata_stores`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., metadata_service.ListMetadataStoresResponse], + request: metadata_service.ListMetadataStoresRequest, + response: metadata_service.ListMetadataStoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListMetadataStoresRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[metadata_service.ListMetadataStoresResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[metadata_store.MetadataStore]: + for page in self.pages: + yield from page.metadata_stores + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListMetadataStoresAsyncPager: + """A pager for iterating through ``list_metadata_stores`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``metadata_stores`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListMetadataStores`` requests and continue to iterate + through the ``metadata_stores`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[metadata_service.ListMetadataStoresResponse]], + request: metadata_service.ListMetadataStoresRequest, + response: metadata_service.ListMetadataStoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListMetadataStoresRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[metadata_service.ListMetadataStoresResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[metadata_store.MetadataStore]: + async def async_generator(): + async for page in self.pages: + for response in page.metadata_stores: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListArtifactsPager: + """A pager for iterating through ``list_artifacts`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``artifacts`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListArtifacts`` requests and continue to iterate + through the ``artifacts`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., metadata_service.ListArtifactsResponse], + request: metadata_service.ListArtifactsRequest, + response: metadata_service.ListArtifactsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListArtifactsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListArtifactsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[metadata_service.ListArtifactsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[artifact.Artifact]: + for page in self.pages: + yield from page.artifacts + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListArtifactsAsyncPager: + """A pager for iterating through ``list_artifacts`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``artifacts`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListArtifacts`` requests and continue to iterate + through the ``artifacts`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[metadata_service.ListArtifactsResponse]], + request: metadata_service.ListArtifactsRequest, + response: metadata_service.ListArtifactsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListArtifactsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListArtifactsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[metadata_service.ListArtifactsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[artifact.Artifact]: + async def async_generator(): + async for page in self.pages: + for response in page.artifacts: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListContextsPager: + """A pager for iterating through ``list_contexts`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListContextsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``contexts`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListContexts`` requests and continue to iterate + through the ``contexts`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListContextsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., metadata_service.ListContextsResponse], + request: metadata_service.ListContextsRequest, + response: metadata_service.ListContextsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListContextsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListContextsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListContextsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[metadata_service.ListContextsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[context.Context]: + for page in self.pages: + yield from page.contexts + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListContextsAsyncPager: + """A pager for iterating through ``list_contexts`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListContextsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``contexts`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListContexts`` requests and continue to iterate + through the ``contexts`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListContextsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[metadata_service.ListContextsResponse]], + request: metadata_service.ListContextsRequest, + response: metadata_service.ListContextsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListContextsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListContextsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListContextsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[metadata_service.ListContextsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[context.Context]: + async def async_generator(): + async for page in self.pages: + for response in page.contexts: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListExecutionsPager: + """A pager for iterating through ``list_executions`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``executions`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListExecutions`` requests and continue to iterate + through the ``executions`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., metadata_service.ListExecutionsResponse], + request: metadata_service.ListExecutionsRequest, + response: metadata_service.ListExecutionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListExecutionsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListExecutionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[metadata_service.ListExecutionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[execution.Execution]: + for page in self.pages: + yield from page.executions + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListExecutionsAsyncPager: + """A pager for iterating through ``list_executions`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``executions`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListExecutions`` requests and continue to iterate + through the ``executions`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[metadata_service.ListExecutionsResponse]], + request: metadata_service.ListExecutionsRequest, + response: metadata_service.ListExecutionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListExecutionsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListExecutionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[metadata_service.ListExecutionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[execution.Execution]: + async def async_generator(): + async for page in self.pages: + for response in page.executions: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListMetadataSchemasPager: + """A pager for iterating through ``list_metadata_schemas`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse` object, and + provides an ``__iter__`` method to iterate through its + ``metadata_schemas`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListMetadataSchemas`` requests and continue to iterate + through the ``metadata_schemas`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., metadata_service.ListMetadataSchemasResponse], + request: metadata_service.ListMetadataSchemasRequest, + response: metadata_service.ListMetadataSchemasResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListMetadataSchemasRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[metadata_service.ListMetadataSchemasResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[metadata_schema.MetadataSchema]: + for page in self.pages: + yield from page.metadata_schemas + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListMetadataSchemasAsyncPager: + """A pager for iterating through ``list_metadata_schemas`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``metadata_schemas`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListMetadataSchemas`` requests and continue to iterate + through the ``metadata_schemas`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[metadata_service.ListMetadataSchemasResponse]], + request: metadata_service.ListMetadataSchemasRequest, + response: metadata_service.ListMetadataSchemasResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = metadata_service.ListMetadataSchemasRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[metadata_service.ListMetadataSchemasResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[metadata_schema.MetadataSchema]: + async def async_generator(): + async for page in self.pages: + for response in page.metadata_schemas: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py new file mode 100644 index 0000000000..67031880cd --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import MetadataServiceTransport +from .grpc import MetadataServiceGrpcTransport +from .grpc_asyncio import MetadataServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[MetadataServiceTransport]] +_transport_registry['grpc'] = MetadataServiceGrpcTransport +_transport_registry['grpc_asyncio'] = MetadataServiceGrpcAsyncIOTransport + +__all__ = ( + 'MetadataServiceTransport', + 'MetadataServiceGrpcTransport', + 'MetadataServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py new file mode 100644 index 0000000000..76ef934c98 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py @@ -0,0 +1,480 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import lineage_subgraph +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store +from google.longrunning import operations_pb2 as operations # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class MetadataServiceTransport(abc.ABC): + """Abstract transport class for MetadataService.""" + + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) + + def __init__( + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_metadata_store: gapic_v1.method.wrap_method( + self.create_metadata_store, + default_timeout=None, + client_info=client_info, + ), + self.get_metadata_store: gapic_v1.method.wrap_method( + self.get_metadata_store, + default_timeout=None, + client_info=client_info, + ), + self.list_metadata_stores: gapic_v1.method.wrap_method( + self.list_metadata_stores, + default_timeout=None, + client_info=client_info, + ), + self.delete_metadata_store: gapic_v1.method.wrap_method( + self.delete_metadata_store, + default_timeout=None, + client_info=client_info, + ), + self.create_artifact: gapic_v1.method.wrap_method( + self.create_artifact, + default_timeout=None, + client_info=client_info, + ), + self.get_artifact: gapic_v1.method.wrap_method( + self.get_artifact, + default_timeout=None, + client_info=client_info, + ), + self.list_artifacts: gapic_v1.method.wrap_method( + self.list_artifacts, + default_timeout=None, + client_info=client_info, + ), + self.update_artifact: gapic_v1.method.wrap_method( + self.update_artifact, + default_timeout=None, + client_info=client_info, + ), + self.create_context: gapic_v1.method.wrap_method( + self.create_context, + default_timeout=None, + client_info=client_info, + ), + self.get_context: gapic_v1.method.wrap_method( + self.get_context, + default_timeout=None, + client_info=client_info, + ), + self.list_contexts: gapic_v1.method.wrap_method( + self.list_contexts, + default_timeout=None, + client_info=client_info, + ), + self.update_context: gapic_v1.method.wrap_method( + self.update_context, + default_timeout=None, + client_info=client_info, + ), + self.delete_context: gapic_v1.method.wrap_method( + self.delete_context, + default_timeout=None, + client_info=client_info, + ), + self.add_context_artifacts_and_executions: gapic_v1.method.wrap_method( + self.add_context_artifacts_and_executions, + default_timeout=None, + client_info=client_info, + ), + self.add_context_children: gapic_v1.method.wrap_method( + self.add_context_children, + default_timeout=None, + client_info=client_info, + ), + self.query_context_lineage_subgraph: gapic_v1.method.wrap_method( + self.query_context_lineage_subgraph, + default_timeout=None, + client_info=client_info, + ), + self.create_execution: gapic_v1.method.wrap_method( + self.create_execution, + default_timeout=None, + client_info=client_info, + ), + self.get_execution: gapic_v1.method.wrap_method( + self.get_execution, + default_timeout=None, + client_info=client_info, + ), + self.list_executions: gapic_v1.method.wrap_method( + self.list_executions, + default_timeout=None, + client_info=client_info, + ), + self.update_execution: gapic_v1.method.wrap_method( + self.update_execution, + default_timeout=None, + client_info=client_info, + ), + self.add_execution_events: gapic_v1.method.wrap_method( + self.add_execution_events, + default_timeout=None, + client_info=client_info, + ), + self.query_execution_inputs_and_outputs: gapic_v1.method.wrap_method( + self.query_execution_inputs_and_outputs, + default_timeout=None, + client_info=client_info, + ), + self.create_metadata_schema: gapic_v1.method.wrap_method( + self.create_metadata_schema, + default_timeout=None, + client_info=client_info, + ), + self.get_metadata_schema: gapic_v1.method.wrap_method( + self.get_metadata_schema, + default_timeout=None, + client_info=client_info, + ), + self.list_metadata_schemas: gapic_v1.method.wrap_method( + self.list_metadata_schemas, + default_timeout=None, + client_info=client_info, + ), + + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_metadata_store(self) -> typing.Callable[ + [metadata_service.CreateMetadataStoreRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def get_metadata_store(self) -> typing.Callable[ + [metadata_service.GetMetadataStoreRequest], + typing.Union[ + metadata_store.MetadataStore, + typing.Awaitable[metadata_store.MetadataStore] + ]]: + raise NotImplementedError() + + @property + def list_metadata_stores(self) -> typing.Callable[ + [metadata_service.ListMetadataStoresRequest], + typing.Union[ + metadata_service.ListMetadataStoresResponse, + typing.Awaitable[metadata_service.ListMetadataStoresResponse] + ]]: + raise NotImplementedError() + + @property + def delete_metadata_store(self) -> typing.Callable[ + [metadata_service.DeleteMetadataStoreRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def create_artifact(self) -> typing.Callable[ + [metadata_service.CreateArtifactRequest], + typing.Union[ + gca_artifact.Artifact, + typing.Awaitable[gca_artifact.Artifact] + ]]: + raise NotImplementedError() + + @property + def get_artifact(self) -> typing.Callable[ + [metadata_service.GetArtifactRequest], + typing.Union[ + artifact.Artifact, + typing.Awaitable[artifact.Artifact] + ]]: + raise NotImplementedError() + + @property + def list_artifacts(self) -> typing.Callable[ + [metadata_service.ListArtifactsRequest], + typing.Union[ + metadata_service.ListArtifactsResponse, + typing.Awaitable[metadata_service.ListArtifactsResponse] + ]]: + raise NotImplementedError() + + @property + def update_artifact(self) -> typing.Callable[ + [metadata_service.UpdateArtifactRequest], + typing.Union[ + gca_artifact.Artifact, + typing.Awaitable[gca_artifact.Artifact] + ]]: + raise NotImplementedError() + + @property + def create_context(self) -> typing.Callable[ + [metadata_service.CreateContextRequest], + typing.Union[ + gca_context.Context, + typing.Awaitable[gca_context.Context] + ]]: + raise NotImplementedError() + + @property + def get_context(self) -> typing.Callable[ + [metadata_service.GetContextRequest], + typing.Union[ + context.Context, + typing.Awaitable[context.Context] + ]]: + raise NotImplementedError() + + @property + def list_contexts(self) -> typing.Callable[ + [metadata_service.ListContextsRequest], + typing.Union[ + metadata_service.ListContextsResponse, + typing.Awaitable[metadata_service.ListContextsResponse] + ]]: + raise NotImplementedError() + + @property + def update_context(self) -> typing.Callable[ + [metadata_service.UpdateContextRequest], + typing.Union[ + gca_context.Context, + typing.Awaitable[gca_context.Context] + ]]: + raise NotImplementedError() + + @property + def delete_context(self) -> typing.Callable[ + [metadata_service.DeleteContextRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def add_context_artifacts_and_executions(self) -> typing.Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + typing.Union[ + metadata_service.AddContextArtifactsAndExecutionsResponse, + typing.Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse] + ]]: + raise NotImplementedError() + + @property + def add_context_children(self) -> typing.Callable[ + [metadata_service.AddContextChildrenRequest], + typing.Union[ + metadata_service.AddContextChildrenResponse, + typing.Awaitable[metadata_service.AddContextChildrenResponse] + ]]: + raise NotImplementedError() + + @property + def query_context_lineage_subgraph(self) -> typing.Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph] + ]]: + raise NotImplementedError() + + @property + def create_execution(self) -> typing.Callable[ + [metadata_service.CreateExecutionRequest], + typing.Union[ + gca_execution.Execution, + typing.Awaitable[gca_execution.Execution] + ]]: + raise NotImplementedError() + + @property + def get_execution(self) -> typing.Callable[ + [metadata_service.GetExecutionRequest], + typing.Union[ + execution.Execution, + typing.Awaitable[execution.Execution] + ]]: + raise NotImplementedError() + + @property + def list_executions(self) -> typing.Callable[ + [metadata_service.ListExecutionsRequest], + typing.Union[ + metadata_service.ListExecutionsResponse, + typing.Awaitable[metadata_service.ListExecutionsResponse] + ]]: + raise NotImplementedError() + + @property + def update_execution(self) -> typing.Callable[ + [metadata_service.UpdateExecutionRequest], + typing.Union[ + gca_execution.Execution, + typing.Awaitable[gca_execution.Execution] + ]]: + raise NotImplementedError() + + @property + def add_execution_events(self) -> typing.Callable[ + [metadata_service.AddExecutionEventsRequest], + typing.Union[ + metadata_service.AddExecutionEventsResponse, + typing.Awaitable[metadata_service.AddExecutionEventsResponse] + ]]: + raise NotImplementedError() + + @property + def query_execution_inputs_and_outputs(self) -> typing.Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph] + ]]: + raise NotImplementedError() + + @property + def create_metadata_schema(self) -> typing.Callable[ + [metadata_service.CreateMetadataSchemaRequest], + typing.Union[ + gca_metadata_schema.MetadataSchema, + typing.Awaitable[gca_metadata_schema.MetadataSchema] + ]]: + raise NotImplementedError() + + @property + def get_metadata_schema(self) -> typing.Callable[ + [metadata_service.GetMetadataSchemaRequest], + typing.Union[ + metadata_schema.MetadataSchema, + typing.Awaitable[metadata_schema.MetadataSchema] + ]]: + raise NotImplementedError() + + @property + def list_metadata_schemas(self) -> typing.Callable[ + [metadata_service.ListMetadataSchemasRequest], + typing.Union[ + metadata_service.ListMetadataSchemasResponse, + typing.Awaitable[metadata_service.ListMetadataSchemasResponse] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'MetadataServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py new file mode 100644 index 0000000000..7cc6484f91 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -0,0 +1,917 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import lineage_subgraph +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import MetadataServiceTransport, DEFAULT_CLIENT_INFO + + +class MetadataServiceGrpcTransport(MetadataServiceTransport): + """gRPC backend transport for MetadataService. + + Service for reading and writing metadata entries. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_metadata_store(self) -> Callable[ + [metadata_service.CreateMetadataStoreRequest], + operations.Operation]: + r"""Return a callable for the create metadata store method over gRPC. + + Initializes a MetadataStore, including allocation of + resources. + + Returns: + Callable[[~.CreateMetadataStoreRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_metadata_store' not in self._stubs: + self._stubs['create_metadata_store'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataStore', + request_serializer=metadata_service.CreateMetadataStoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_metadata_store'] + + @property + def get_metadata_store(self) -> Callable[ + [metadata_service.GetMetadataStoreRequest], + metadata_store.MetadataStore]: + r"""Return a callable for the get metadata store method over gRPC. + + Retrieves a specific MetadataStore. + + Returns: + Callable[[~.GetMetadataStoreRequest], + ~.MetadataStore]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_metadata_store' not in self._stubs: + self._stubs['get_metadata_store'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataStore', + request_serializer=metadata_service.GetMetadataStoreRequest.serialize, + response_deserializer=metadata_store.MetadataStore.deserialize, + ) + return self._stubs['get_metadata_store'] + + @property + def list_metadata_stores(self) -> Callable[ + [metadata_service.ListMetadataStoresRequest], + metadata_service.ListMetadataStoresResponse]: + r"""Return a callable for the list metadata stores method over gRPC. + + Lists MetadataStores for a Location. + + Returns: + Callable[[~.ListMetadataStoresRequest], + ~.ListMetadataStoresResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_metadata_stores' not in self._stubs: + self._stubs['list_metadata_stores'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataStores', + request_serializer=metadata_service.ListMetadataStoresRequest.serialize, + response_deserializer=metadata_service.ListMetadataStoresResponse.deserialize, + ) + return self._stubs['list_metadata_stores'] + + @property + def delete_metadata_store(self) -> Callable[ + [metadata_service.DeleteMetadataStoreRequest], + operations.Operation]: + r"""Return a callable for the delete metadata store method over gRPC. + + Deletes a single MetadataStore. + + Returns: + Callable[[~.DeleteMetadataStoreRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_metadata_store' not in self._stubs: + self._stubs['delete_metadata_store'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteMetadataStore', + request_serializer=metadata_service.DeleteMetadataStoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_metadata_store'] + + @property + def create_artifact(self) -> Callable[ + [metadata_service.CreateArtifactRequest], + gca_artifact.Artifact]: + r"""Return a callable for the create artifact method over gRPC. + + Creates an Artifact associated with a MetadataStore. + + Returns: + Callable[[~.CreateArtifactRequest], + ~.Artifact]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_artifact' not in self._stubs: + self._stubs['create_artifact'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateArtifact', + request_serializer=metadata_service.CreateArtifactRequest.serialize, + response_deserializer=gca_artifact.Artifact.deserialize, + ) + return self._stubs['create_artifact'] + + @property + def get_artifact(self) -> Callable[ + [metadata_service.GetArtifactRequest], + artifact.Artifact]: + r"""Return a callable for the get artifact method over gRPC. + + Retrieves a specific Artifact. + + Returns: + Callable[[~.GetArtifactRequest], + ~.Artifact]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_artifact' not in self._stubs: + self._stubs['get_artifact'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetArtifact', + request_serializer=metadata_service.GetArtifactRequest.serialize, + response_deserializer=artifact.Artifact.deserialize, + ) + return self._stubs['get_artifact'] + + @property + def list_artifacts(self) -> Callable[ + [metadata_service.ListArtifactsRequest], + metadata_service.ListArtifactsResponse]: + r"""Return a callable for the list artifacts method over gRPC. + + Lists Artifacts in the MetadataStore. + + Returns: + Callable[[~.ListArtifactsRequest], + ~.ListArtifactsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_artifacts' not in self._stubs: + self._stubs['list_artifacts'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListArtifacts', + request_serializer=metadata_service.ListArtifactsRequest.serialize, + response_deserializer=metadata_service.ListArtifactsResponse.deserialize, + ) + return self._stubs['list_artifacts'] + + @property + def update_artifact(self) -> Callable[ + [metadata_service.UpdateArtifactRequest], + gca_artifact.Artifact]: + r"""Return a callable for the update artifact method over gRPC. + + Updates a stored Artifact. + + Returns: + Callable[[~.UpdateArtifactRequest], + ~.Artifact]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_artifact' not in self._stubs: + self._stubs['update_artifact'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateArtifact', + request_serializer=metadata_service.UpdateArtifactRequest.serialize, + response_deserializer=gca_artifact.Artifact.deserialize, + ) + return self._stubs['update_artifact'] + + @property + def create_context(self) -> Callable[ + [metadata_service.CreateContextRequest], + gca_context.Context]: + r"""Return a callable for the create context method over gRPC. + + Creates a Context associated with a MetadataStore. + + Returns: + Callable[[~.CreateContextRequest], + ~.Context]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_context' not in self._stubs: + self._stubs['create_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateContext', + request_serializer=metadata_service.CreateContextRequest.serialize, + response_deserializer=gca_context.Context.deserialize, + ) + return self._stubs['create_context'] + + @property + def get_context(self) -> Callable[ + [metadata_service.GetContextRequest], + context.Context]: + r"""Return a callable for the get context method over gRPC. + + Retrieves a specific Context. + + Returns: + Callable[[~.GetContextRequest], + ~.Context]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_context' not in self._stubs: + self._stubs['get_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetContext', + request_serializer=metadata_service.GetContextRequest.serialize, + response_deserializer=context.Context.deserialize, + ) + return self._stubs['get_context'] + + @property + def list_contexts(self) -> Callable[ + [metadata_service.ListContextsRequest], + metadata_service.ListContextsResponse]: + r"""Return a callable for the list contexts method over gRPC. + + Lists Contexts on the MetadataStore. + + Returns: + Callable[[~.ListContextsRequest], + ~.ListContextsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_contexts' not in self._stubs: + self._stubs['list_contexts'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListContexts', + request_serializer=metadata_service.ListContextsRequest.serialize, + response_deserializer=metadata_service.ListContextsResponse.deserialize, + ) + return self._stubs['list_contexts'] + + @property + def update_context(self) -> Callable[ + [metadata_service.UpdateContextRequest], + gca_context.Context]: + r"""Return a callable for the update context method over gRPC. + + Updates a stored Context. + + Returns: + Callable[[~.UpdateContextRequest], + ~.Context]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_context' not in self._stubs: + self._stubs['update_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateContext', + request_serializer=metadata_service.UpdateContextRequest.serialize, + response_deserializer=gca_context.Context.deserialize, + ) + return self._stubs['update_context'] + + @property + def delete_context(self) -> Callable[ + [metadata_service.DeleteContextRequest], + operations.Operation]: + r"""Return a callable for the delete context method over gRPC. + + Deletes a stored Context. + + Returns: + Callable[[~.DeleteContextRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_context' not in self._stubs: + self._stubs['delete_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteContext', + request_serializer=metadata_service.DeleteContextRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_context'] + + @property + def add_context_artifacts_and_executions(self) -> Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + metadata_service.AddContextArtifactsAndExecutionsResponse]: + r"""Return a callable for the add context artifacts and + executions method over gRPC. + + Adds a set of Artifacts and Executions to a Context. + If any of the Artifacts or Executions have already been + added to a Context, they are simply skipped. + + Returns: + Callable[[~.AddContextArtifactsAndExecutionsRequest], + ~.AddContextArtifactsAndExecutionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'add_context_artifacts_and_executions' not in self._stubs: + self._stubs['add_context_artifacts_and_executions'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextArtifactsAndExecutions', + request_serializer=metadata_service.AddContextArtifactsAndExecutionsRequest.serialize, + response_deserializer=metadata_service.AddContextArtifactsAndExecutionsResponse.deserialize, + ) + return self._stubs['add_context_artifacts_and_executions'] + + @property + def add_context_children(self) -> Callable[ + [metadata_service.AddContextChildrenRequest], + metadata_service.AddContextChildrenResponse]: + r"""Return a callable for the add context children method over gRPC. + + Adds a set of Contexts as children to a parent Context. If any + of the child Contexts have already been added to the parent + Context, they are simply skipped. If this call would create a + cycle or cause any Context to have more than 10 parents, the + request will fail with INVALID_ARGUMENT error. + + Returns: + Callable[[~.AddContextChildrenRequest], + ~.AddContextChildrenResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'add_context_children' not in self._stubs: + self._stubs['add_context_children'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextChildren', + request_serializer=metadata_service.AddContextChildrenRequest.serialize, + response_deserializer=metadata_service.AddContextChildrenResponse.deserialize, + ) + return self._stubs['add_context_children'] + + @property + def query_context_lineage_subgraph(self) -> Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + lineage_subgraph.LineageSubgraph]: + r"""Return a callable for the query context lineage subgraph method over gRPC. + + Retrieves Artifacts and Executions within the + specified Context, connected by Event edges and returned + as a LineageSubgraph. + + Returns: + Callable[[~.QueryContextLineageSubgraphRequest], + ~.LineageSubgraph]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'query_context_lineage_subgraph' not in self._stubs: + self._stubs['query_context_lineage_subgraph'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/QueryContextLineageSubgraph', + request_serializer=metadata_service.QueryContextLineageSubgraphRequest.serialize, + response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, + ) + return self._stubs['query_context_lineage_subgraph'] + + @property + def create_execution(self) -> Callable[ + [metadata_service.CreateExecutionRequest], + gca_execution.Execution]: + r"""Return a callable for the create execution method over gRPC. + + Creates an Execution associated with a MetadataStore. + + Returns: + Callable[[~.CreateExecutionRequest], + ~.Execution]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_execution' not in self._stubs: + self._stubs['create_execution'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateExecution', + request_serializer=metadata_service.CreateExecutionRequest.serialize, + response_deserializer=gca_execution.Execution.deserialize, + ) + return self._stubs['create_execution'] + + @property + def get_execution(self) -> Callable[ + [metadata_service.GetExecutionRequest], + execution.Execution]: + r"""Return a callable for the get execution method over gRPC. + + Retrieves a specific Execution. + + Returns: + Callable[[~.GetExecutionRequest], + ~.Execution]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_execution' not in self._stubs: + self._stubs['get_execution'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetExecution', + request_serializer=metadata_service.GetExecutionRequest.serialize, + response_deserializer=execution.Execution.deserialize, + ) + return self._stubs['get_execution'] + + @property + def list_executions(self) -> Callable[ + [metadata_service.ListExecutionsRequest], + metadata_service.ListExecutionsResponse]: + r"""Return a callable for the list executions method over gRPC. + + Lists Executions in the MetadataStore. + + Returns: + Callable[[~.ListExecutionsRequest], + ~.ListExecutionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_executions' not in self._stubs: + self._stubs['list_executions'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListExecutions', + request_serializer=metadata_service.ListExecutionsRequest.serialize, + response_deserializer=metadata_service.ListExecutionsResponse.deserialize, + ) + return self._stubs['list_executions'] + + @property + def update_execution(self) -> Callable[ + [metadata_service.UpdateExecutionRequest], + gca_execution.Execution]: + r"""Return a callable for the update execution method over gRPC. + + Updates a stored Execution. + + Returns: + Callable[[~.UpdateExecutionRequest], + ~.Execution]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_execution' not in self._stubs: + self._stubs['update_execution'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateExecution', + request_serializer=metadata_service.UpdateExecutionRequest.serialize, + response_deserializer=gca_execution.Execution.deserialize, + ) + return self._stubs['update_execution'] + + @property + def add_execution_events(self) -> Callable[ + [metadata_service.AddExecutionEventsRequest], + metadata_service.AddExecutionEventsResponse]: + r"""Return a callable for the add execution events method over gRPC. + + Adds Events for denoting whether each Artifact was an + input or output for a given Execution. If any Events + already exist between the Execution and any of the + specified Artifacts they are simply skipped. + + Returns: + Callable[[~.AddExecutionEventsRequest], + ~.AddExecutionEventsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'add_execution_events' not in self._stubs: + self._stubs['add_execution_events'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/AddExecutionEvents', + request_serializer=metadata_service.AddExecutionEventsRequest.serialize, + response_deserializer=metadata_service.AddExecutionEventsResponse.deserialize, + ) + return self._stubs['add_execution_events'] + + @property + def query_execution_inputs_and_outputs(self) -> Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + lineage_subgraph.LineageSubgraph]: + r"""Return a callable for the query execution inputs and + outputs method over gRPC. + + Obtains the set of input and output Artifacts for + this Execution, in the form of LineageSubgraph that also + contains the Execution and connecting Events. + + Returns: + Callable[[~.QueryExecutionInputsAndOutputsRequest], + ~.LineageSubgraph]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'query_execution_inputs_and_outputs' not in self._stubs: + self._stubs['query_execution_inputs_and_outputs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/QueryExecutionInputsAndOutputs', + request_serializer=metadata_service.QueryExecutionInputsAndOutputsRequest.serialize, + response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, + ) + return self._stubs['query_execution_inputs_and_outputs'] + + @property + def create_metadata_schema(self) -> Callable[ + [metadata_service.CreateMetadataSchemaRequest], + gca_metadata_schema.MetadataSchema]: + r"""Return a callable for the create metadata schema method over gRPC. + + Creates an MetadataSchema. + + Returns: + Callable[[~.CreateMetadataSchemaRequest], + ~.MetadataSchema]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_metadata_schema' not in self._stubs: + self._stubs['create_metadata_schema'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataSchema', + request_serializer=metadata_service.CreateMetadataSchemaRequest.serialize, + response_deserializer=gca_metadata_schema.MetadataSchema.deserialize, + ) + return self._stubs['create_metadata_schema'] + + @property + def get_metadata_schema(self) -> Callable[ + [metadata_service.GetMetadataSchemaRequest], + metadata_schema.MetadataSchema]: + r"""Return a callable for the get metadata schema method over gRPC. + + Retrieves a specific MetadataSchema. + + Returns: + Callable[[~.GetMetadataSchemaRequest], + ~.MetadataSchema]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_metadata_schema' not in self._stubs: + self._stubs['get_metadata_schema'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataSchema', + request_serializer=metadata_service.GetMetadataSchemaRequest.serialize, + response_deserializer=metadata_schema.MetadataSchema.deserialize, + ) + return self._stubs['get_metadata_schema'] + + @property + def list_metadata_schemas(self) -> Callable[ + [metadata_service.ListMetadataSchemasRequest], + metadata_service.ListMetadataSchemasResponse]: + r"""Return a callable for the list metadata schemas method over gRPC. + + Lists MetadataSchemas. + + Returns: + Callable[[~.ListMetadataSchemasRequest], + ~.ListMetadataSchemasResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_metadata_schemas' not in self._stubs: + self._stubs['list_metadata_schemas'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataSchemas', + request_serializer=metadata_service.ListMetadataSchemasRequest.serialize, + response_deserializer=metadata_service.ListMetadataSchemasResponse.deserialize, + ) + return self._stubs['list_metadata_schemas'] + + +__all__ = ( + 'MetadataServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..bedea761c0 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -0,0 +1,922 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import lineage_subgraph +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import MetadataServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import MetadataServiceGrpcTransport + + +class MetadataServiceGrpcAsyncIOTransport(MetadataServiceTransport): + """gRPC AsyncIO backend transport for MetadataService. + + Service for reading and writing metadata entries. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_metadata_store(self) -> Callable[ + [metadata_service.CreateMetadataStoreRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create metadata store method over gRPC. + + Initializes a MetadataStore, including allocation of + resources. + + Returns: + Callable[[~.CreateMetadataStoreRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_metadata_store' not in self._stubs: + self._stubs['create_metadata_store'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataStore', + request_serializer=metadata_service.CreateMetadataStoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_metadata_store'] + + @property + def get_metadata_store(self) -> Callable[ + [metadata_service.GetMetadataStoreRequest], + Awaitable[metadata_store.MetadataStore]]: + r"""Return a callable for the get metadata store method over gRPC. + + Retrieves a specific MetadataStore. + + Returns: + Callable[[~.GetMetadataStoreRequest], + Awaitable[~.MetadataStore]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_metadata_store' not in self._stubs: + self._stubs['get_metadata_store'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataStore', + request_serializer=metadata_service.GetMetadataStoreRequest.serialize, + response_deserializer=metadata_store.MetadataStore.deserialize, + ) + return self._stubs['get_metadata_store'] + + @property + def list_metadata_stores(self) -> Callable[ + [metadata_service.ListMetadataStoresRequest], + Awaitable[metadata_service.ListMetadataStoresResponse]]: + r"""Return a callable for the list metadata stores method over gRPC. + + Lists MetadataStores for a Location. + + Returns: + Callable[[~.ListMetadataStoresRequest], + Awaitable[~.ListMetadataStoresResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_metadata_stores' not in self._stubs: + self._stubs['list_metadata_stores'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataStores', + request_serializer=metadata_service.ListMetadataStoresRequest.serialize, + response_deserializer=metadata_service.ListMetadataStoresResponse.deserialize, + ) + return self._stubs['list_metadata_stores'] + + @property + def delete_metadata_store(self) -> Callable[ + [metadata_service.DeleteMetadataStoreRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete metadata store method over gRPC. + + Deletes a single MetadataStore. + + Returns: + Callable[[~.DeleteMetadataStoreRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_metadata_store' not in self._stubs: + self._stubs['delete_metadata_store'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteMetadataStore', + request_serializer=metadata_service.DeleteMetadataStoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_metadata_store'] + + @property + def create_artifact(self) -> Callable[ + [metadata_service.CreateArtifactRequest], + Awaitable[gca_artifact.Artifact]]: + r"""Return a callable for the create artifact method over gRPC. + + Creates an Artifact associated with a MetadataStore. + + Returns: + Callable[[~.CreateArtifactRequest], + Awaitable[~.Artifact]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_artifact' not in self._stubs: + self._stubs['create_artifact'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateArtifact', + request_serializer=metadata_service.CreateArtifactRequest.serialize, + response_deserializer=gca_artifact.Artifact.deserialize, + ) + return self._stubs['create_artifact'] + + @property + def get_artifact(self) -> Callable[ + [metadata_service.GetArtifactRequest], + Awaitable[artifact.Artifact]]: + r"""Return a callable for the get artifact method over gRPC. + + Retrieves a specific Artifact. + + Returns: + Callable[[~.GetArtifactRequest], + Awaitable[~.Artifact]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_artifact' not in self._stubs: + self._stubs['get_artifact'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetArtifact', + request_serializer=metadata_service.GetArtifactRequest.serialize, + response_deserializer=artifact.Artifact.deserialize, + ) + return self._stubs['get_artifact'] + + @property + def list_artifacts(self) -> Callable[ + [metadata_service.ListArtifactsRequest], + Awaitable[metadata_service.ListArtifactsResponse]]: + r"""Return a callable for the list artifacts method over gRPC. + + Lists Artifacts in the MetadataStore. + + Returns: + Callable[[~.ListArtifactsRequest], + Awaitable[~.ListArtifactsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_artifacts' not in self._stubs: + self._stubs['list_artifacts'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListArtifacts', + request_serializer=metadata_service.ListArtifactsRequest.serialize, + response_deserializer=metadata_service.ListArtifactsResponse.deserialize, + ) + return self._stubs['list_artifacts'] + + @property + def update_artifact(self) -> Callable[ + [metadata_service.UpdateArtifactRequest], + Awaitable[gca_artifact.Artifact]]: + r"""Return a callable for the update artifact method over gRPC. + + Updates a stored Artifact. + + Returns: + Callable[[~.UpdateArtifactRequest], + Awaitable[~.Artifact]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_artifact' not in self._stubs: + self._stubs['update_artifact'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateArtifact', + request_serializer=metadata_service.UpdateArtifactRequest.serialize, + response_deserializer=gca_artifact.Artifact.deserialize, + ) + return self._stubs['update_artifact'] + + @property + def create_context(self) -> Callable[ + [metadata_service.CreateContextRequest], + Awaitable[gca_context.Context]]: + r"""Return a callable for the create context method over gRPC. + + Creates a Context associated with a MetadataStore. + + Returns: + Callable[[~.CreateContextRequest], + Awaitable[~.Context]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_context' not in self._stubs: + self._stubs['create_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateContext', + request_serializer=metadata_service.CreateContextRequest.serialize, + response_deserializer=gca_context.Context.deserialize, + ) + return self._stubs['create_context'] + + @property + def get_context(self) -> Callable[ + [metadata_service.GetContextRequest], + Awaitable[context.Context]]: + r"""Return a callable for the get context method over gRPC. + + Retrieves a specific Context. + + Returns: + Callable[[~.GetContextRequest], + Awaitable[~.Context]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_context' not in self._stubs: + self._stubs['get_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetContext', + request_serializer=metadata_service.GetContextRequest.serialize, + response_deserializer=context.Context.deserialize, + ) + return self._stubs['get_context'] + + @property + def list_contexts(self) -> Callable[ + [metadata_service.ListContextsRequest], + Awaitable[metadata_service.ListContextsResponse]]: + r"""Return a callable for the list contexts method over gRPC. + + Lists Contexts on the MetadataStore. + + Returns: + Callable[[~.ListContextsRequest], + Awaitable[~.ListContextsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_contexts' not in self._stubs: + self._stubs['list_contexts'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListContexts', + request_serializer=metadata_service.ListContextsRequest.serialize, + response_deserializer=metadata_service.ListContextsResponse.deserialize, + ) + return self._stubs['list_contexts'] + + @property + def update_context(self) -> Callable[ + [metadata_service.UpdateContextRequest], + Awaitable[gca_context.Context]]: + r"""Return a callable for the update context method over gRPC. + + Updates a stored Context. + + Returns: + Callable[[~.UpdateContextRequest], + Awaitable[~.Context]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_context' not in self._stubs: + self._stubs['update_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateContext', + request_serializer=metadata_service.UpdateContextRequest.serialize, + response_deserializer=gca_context.Context.deserialize, + ) + return self._stubs['update_context'] + + @property + def delete_context(self) -> Callable[ + [metadata_service.DeleteContextRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete context method over gRPC. + + Deletes a stored Context. + + Returns: + Callable[[~.DeleteContextRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_context' not in self._stubs: + self._stubs['delete_context'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteContext', + request_serializer=metadata_service.DeleteContextRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_context'] + + @property + def add_context_artifacts_and_executions(self) -> Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse]]: + r"""Return a callable for the add context artifacts and + executions method over gRPC. + + Adds a set of Artifacts and Executions to a Context. + If any of the Artifacts or Executions have already been + added to a Context, they are simply skipped. + + Returns: + Callable[[~.AddContextArtifactsAndExecutionsRequest], + Awaitable[~.AddContextArtifactsAndExecutionsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'add_context_artifacts_and_executions' not in self._stubs: + self._stubs['add_context_artifacts_and_executions'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextArtifactsAndExecutions', + request_serializer=metadata_service.AddContextArtifactsAndExecutionsRequest.serialize, + response_deserializer=metadata_service.AddContextArtifactsAndExecutionsResponse.deserialize, + ) + return self._stubs['add_context_artifacts_and_executions'] + + @property + def add_context_children(self) -> Callable[ + [metadata_service.AddContextChildrenRequest], + Awaitable[metadata_service.AddContextChildrenResponse]]: + r"""Return a callable for the add context children method over gRPC. + + Adds a set of Contexts as children to a parent Context. If any + of the child Contexts have already been added to the parent + Context, they are simply skipped. If this call would create a + cycle or cause any Context to have more than 10 parents, the + request will fail with INVALID_ARGUMENT error. + + Returns: + Callable[[~.AddContextChildrenRequest], + Awaitable[~.AddContextChildrenResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'add_context_children' not in self._stubs: + self._stubs['add_context_children'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextChildren', + request_serializer=metadata_service.AddContextChildrenRequest.serialize, + response_deserializer=metadata_service.AddContextChildrenResponse.deserialize, + ) + return self._stubs['add_context_children'] + + @property + def query_context_lineage_subgraph(self) -> Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + Awaitable[lineage_subgraph.LineageSubgraph]]: + r"""Return a callable for the query context lineage subgraph method over gRPC. + + Retrieves Artifacts and Executions within the + specified Context, connected by Event edges and returned + as a LineageSubgraph. + + Returns: + Callable[[~.QueryContextLineageSubgraphRequest], + Awaitable[~.LineageSubgraph]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'query_context_lineage_subgraph' not in self._stubs: + self._stubs['query_context_lineage_subgraph'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/QueryContextLineageSubgraph', + request_serializer=metadata_service.QueryContextLineageSubgraphRequest.serialize, + response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, + ) + return self._stubs['query_context_lineage_subgraph'] + + @property + def create_execution(self) -> Callable[ + [metadata_service.CreateExecutionRequest], + Awaitable[gca_execution.Execution]]: + r"""Return a callable for the create execution method over gRPC. + + Creates an Execution associated with a MetadataStore. + + Returns: + Callable[[~.CreateExecutionRequest], + Awaitable[~.Execution]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_execution' not in self._stubs: + self._stubs['create_execution'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateExecution', + request_serializer=metadata_service.CreateExecutionRequest.serialize, + response_deserializer=gca_execution.Execution.deserialize, + ) + return self._stubs['create_execution'] + + @property + def get_execution(self) -> Callable[ + [metadata_service.GetExecutionRequest], + Awaitable[execution.Execution]]: + r"""Return a callable for the get execution method over gRPC. + + Retrieves a specific Execution. + + Returns: + Callable[[~.GetExecutionRequest], + Awaitable[~.Execution]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_execution' not in self._stubs: + self._stubs['get_execution'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetExecution', + request_serializer=metadata_service.GetExecutionRequest.serialize, + response_deserializer=execution.Execution.deserialize, + ) + return self._stubs['get_execution'] + + @property + def list_executions(self) -> Callable[ + [metadata_service.ListExecutionsRequest], + Awaitable[metadata_service.ListExecutionsResponse]]: + r"""Return a callable for the list executions method over gRPC. + + Lists Executions in the MetadataStore. + + Returns: + Callable[[~.ListExecutionsRequest], + Awaitable[~.ListExecutionsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_executions' not in self._stubs: + self._stubs['list_executions'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListExecutions', + request_serializer=metadata_service.ListExecutionsRequest.serialize, + response_deserializer=metadata_service.ListExecutionsResponse.deserialize, + ) + return self._stubs['list_executions'] + + @property + def update_execution(self) -> Callable[ + [metadata_service.UpdateExecutionRequest], + Awaitable[gca_execution.Execution]]: + r"""Return a callable for the update execution method over gRPC. + + Updates a stored Execution. + + Returns: + Callable[[~.UpdateExecutionRequest], + Awaitable[~.Execution]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_execution' not in self._stubs: + self._stubs['update_execution'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateExecution', + request_serializer=metadata_service.UpdateExecutionRequest.serialize, + response_deserializer=gca_execution.Execution.deserialize, + ) + return self._stubs['update_execution'] + + @property + def add_execution_events(self) -> Callable[ + [metadata_service.AddExecutionEventsRequest], + Awaitable[metadata_service.AddExecutionEventsResponse]]: + r"""Return a callable for the add execution events method over gRPC. + + Adds Events for denoting whether each Artifact was an + input or output for a given Execution. If any Events + already exist between the Execution and any of the + specified Artifacts they are simply skipped. + + Returns: + Callable[[~.AddExecutionEventsRequest], + Awaitable[~.AddExecutionEventsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'add_execution_events' not in self._stubs: + self._stubs['add_execution_events'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/AddExecutionEvents', + request_serializer=metadata_service.AddExecutionEventsRequest.serialize, + response_deserializer=metadata_service.AddExecutionEventsResponse.deserialize, + ) + return self._stubs['add_execution_events'] + + @property + def query_execution_inputs_and_outputs(self) -> Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + Awaitable[lineage_subgraph.LineageSubgraph]]: + r"""Return a callable for the query execution inputs and + outputs method over gRPC. + + Obtains the set of input and output Artifacts for + this Execution, in the form of LineageSubgraph that also + contains the Execution and connecting Events. + + Returns: + Callable[[~.QueryExecutionInputsAndOutputsRequest], + Awaitable[~.LineageSubgraph]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'query_execution_inputs_and_outputs' not in self._stubs: + self._stubs['query_execution_inputs_and_outputs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/QueryExecutionInputsAndOutputs', + request_serializer=metadata_service.QueryExecutionInputsAndOutputsRequest.serialize, + response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, + ) + return self._stubs['query_execution_inputs_and_outputs'] + + @property + def create_metadata_schema(self) -> Callable[ + [metadata_service.CreateMetadataSchemaRequest], + Awaitable[gca_metadata_schema.MetadataSchema]]: + r"""Return a callable for the create metadata schema method over gRPC. + + Creates an MetadataSchema. + + Returns: + Callable[[~.CreateMetadataSchemaRequest], + Awaitable[~.MetadataSchema]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_metadata_schema' not in self._stubs: + self._stubs['create_metadata_schema'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataSchema', + request_serializer=metadata_service.CreateMetadataSchemaRequest.serialize, + response_deserializer=gca_metadata_schema.MetadataSchema.deserialize, + ) + return self._stubs['create_metadata_schema'] + + @property + def get_metadata_schema(self) -> Callable[ + [metadata_service.GetMetadataSchemaRequest], + Awaitable[metadata_schema.MetadataSchema]]: + r"""Return a callable for the get metadata schema method over gRPC. + + Retrieves a specific MetadataSchema. + + Returns: + Callable[[~.GetMetadataSchemaRequest], + Awaitable[~.MetadataSchema]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_metadata_schema' not in self._stubs: + self._stubs['get_metadata_schema'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataSchema', + request_serializer=metadata_service.GetMetadataSchemaRequest.serialize, + response_deserializer=metadata_schema.MetadataSchema.deserialize, + ) + return self._stubs['get_metadata_schema'] + + @property + def list_metadata_schemas(self) -> Callable[ + [metadata_service.ListMetadataSchemasRequest], + Awaitable[metadata_service.ListMetadataSchemasResponse]]: + r"""Return a callable for the list metadata schemas method over gRPC. + + Lists MetadataSchemas. + + Returns: + Callable[[~.ListMetadataSchemasRequest], + Awaitable[~.ListMetadataSchemasResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_metadata_schemas' not in self._stubs: + self._stubs['list_metadata_schemas'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataSchemas', + request_serializer=metadata_service.ListMetadataSchemasRequest.serialize, + response_deserializer=metadata_service.ListMetadataSchemasResponse.deserialize, + ) + return self._stubs['list_metadata_schemas'] + + +__all__ = ( + 'MetadataServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py index 1d6216d1f7..c533a12b45 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import MigrationServiceAsyncClient __all__ = ( - "MigrationServiceClient", - "MigrationServiceAsyncClient", + 'MigrationServiceClient', + 'MigrationServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index c4db3f14d7..d79e43c9c1 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -51,9 +51,7 @@ class MigrationServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path) - parse_annotated_dataset_path = staticmethod( - MigrationServiceClient.parse_annotated_dataset_path - ) + parse_annotated_dataset_path = staticmethod(MigrationServiceClient.parse_annotated_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) @@ -67,34 +65,20 @@ class MigrationServiceAsyncClient: version_path = staticmethod(MigrationServiceClient.version_path) parse_version_path = staticmethod(MigrationServiceClient.parse_version_path) - common_billing_account_path = staticmethod( - MigrationServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - MigrationServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(MigrationServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(MigrationServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(MigrationServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - MigrationServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(MigrationServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - MigrationServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - MigrationServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(MigrationServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(MigrationServiceClient.parse_common_organization_path) common_project_path = staticmethod(MigrationServiceClient.common_project_path) - parse_common_project_path = staticmethod( - MigrationServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(MigrationServiceClient.parse_common_project_path) common_location_path = staticmethod(MigrationServiceClient.common_location_path) - parse_common_location_path = staticmethod( - MigrationServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(MigrationServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -137,18 +121,14 @@ def transport(self) -> MigrationServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient) - ) + get_transport_class = functools.partial(type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, MigrationServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, MigrationServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -187,17 +167,17 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def search_migratable_resources( - self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesAsyncPager: + async def search_migratable_resources(self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -238,10 +218,8 @@ async def search_migratable_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = migration_service.SearchMigratableResourcesRequest(request) @@ -262,33 +240,40 @@ async def search_migratable_resources( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchMigratableResourcesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def batch_migrate_resources( - self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[ - migration_service.MigrateResourceRequest - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_migrate_resources(self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -337,10 +322,8 @@ async def batch_migrate_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = migration_service.BatchMigrateResourcesRequest(request) @@ -364,11 +347,18 @@ async def batch_migrate_resources( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -382,14 +372,21 @@ async def batch_migrate_resources( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("MigrationServiceAsyncClient",) +__all__ = ( + 'MigrationServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 501f21183f..a636962692 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,14 +50,13 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry['grpc'] = MigrationServiceGrpcTransport + _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry["grpc"] = MigrationServiceGrpcTransport - _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -111,7 +110,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -146,8 +145,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MigrationServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -162,183 +162,143 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path( - project: str, dataset: str, annotated_dataset: str, - ) -> str: + def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( - project=project, dataset=dataset, annotated_dataset=annotated_dataset, - ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str, str]: + def parse_annotated_dataset_path(path: str) -> Dict[str,str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: + def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def version_path(project: str, model: str, version: str,) -> str: + def version_path(project: str,model: str,version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format( - project=project, model=model, version=version, - ) + return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) @staticmethod - def parse_version_path(path: str) -> Dict[str, str]: + def parse_version_path(path: str) -> Dict[str,str]: """Parse a version path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -382,9 +342,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -394,9 +352,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -408,9 +364,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -422,10 +376,8 @@ def __init__( if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -444,15 +396,14 @@ def __init__( client_info=client_info, ) - def search_migratable_resources( - self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources(self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -493,10 +444,8 @@ def search_migratable_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -513,40 +462,45 @@ def search_migratable_resources( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.search_migratable_resources - ] + rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources( - self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[ - migration_service.MigrateResourceRequest - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources(self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -595,10 +549,8 @@ def batch_migrate_resources( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -622,11 +574,18 @@ def batch_migrate_resources( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation.from_gapic( @@ -640,14 +599,21 @@ def batch_migrate_resources( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("MigrationServiceClient",) +__all__ = ( + 'MigrationServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py index f0a1dfa43f..d25339203b 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import migratable_resource from google.cloud.aiplatform_v1beta1.types import migration_service @@ -47,15 +38,12 @@ class SearchMigratableResourcesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., migration_service.SearchMigratableResourcesResponse], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: yield from page.migratable_resources def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class SearchMigratableResourcesAsyncPager: @@ -109,17 +97,12 @@ class SearchMigratableResourcesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[migration_service.SearchMigratableResourcesResponse] - ], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[migration_service.SearchMigratableResourcesResponse]], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -141,9 +124,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + async def pages(self) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -159,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py index 38c72756f6..9fb765fdcc 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] -_transport_registry["grpc"] = MigrationServiceGrpcTransport -_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = MigrationServiceGrpcTransport +_transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport __all__ = ( - "MigrationServiceTransport", - "MigrationServiceGrpcTransport", - "MigrationServiceGrpcAsyncIOTransport", + 'MigrationServiceTransport', + 'MigrationServiceGrpcTransport', + 'MigrationServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py index cbcb288489..ba00adae0e 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -33,29 +33,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class MigrationServiceTransport(abc.ABC): """Abstract transport class for MigrationService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -71,40 +71,38 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -118,6 +116,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + } @property @@ -126,25 +125,24 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def search_migratable_resources( - self, - ) -> typing.Callable[ - [migration_service.SearchMigratableResourcesRequest], - typing.Union[ - migration_service.SearchMigratableResourcesResponse, - typing.Awaitable[migration_service.SearchMigratableResourcesResponse], - ], - ]: + def search_migratable_resources(self) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse] + ]]: raise NotImplementedError() @property - def batch_migrate_resources( - self, - ) -> typing.Callable[ - [migration_service.BatchMigrateResourcesRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def batch_migrate_resources(self) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() -__all__ = ("MigrationServiceTransport",) +__all__ = ( + 'MigrationServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 6789c12718..28a61272bf 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -47,24 +47,21 @@ class MigrationServiceGrpcTransport(MigrationServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -110,7 +107,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -118,70 +118,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -189,32 +169,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -244,12 +212,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -261,18 +230,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def search_migratable_resources( - self, - ) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - migration_service.SearchMigratableResourcesResponse, - ]: + def search_migratable_resources(self) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -290,20 +258,18 @@ def search_migratable_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "search_migratable_resources" not in self._stubs: - self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", + if 'search_migratable_resources' not in self._stubs: + self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs["search_migratable_resources"] + return self._stubs['search_migratable_resources'] @property - def batch_migrate_resources( - self, - ) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], operations.Operation - ]: + def batch_migrate_resources(self) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + operations.Operation]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -320,13 +286,15 @@ def batch_migrate_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "batch_migrate_resources" not in self._stubs: - self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", + if 'batch_migrate_resources' not in self._stubs: + self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["batch_migrate_resources"] + return self._stubs['batch_migrate_resources'] -__all__ = ("MigrationServiceGrpcTransport",) +__all__ = ( + 'MigrationServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 33e96e7170..4648d86616 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import migration_service @@ -54,18 +54,16 @@ class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -91,24 +89,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -143,10 +139,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -155,7 +151,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -163,70 +162,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -234,18 +213,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -274,12 +243,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def search_migratable_resources( - self, - ) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - Awaitable[migration_service.SearchMigratableResourcesResponse], - ]: + def search_migratable_resources(self) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse]]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -297,21 +263,18 @@ def search_migratable_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "search_migratable_resources" not in self._stubs: - self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", + if 'search_migratable_resources' not in self._stubs: + self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs["search_migratable_resources"] + return self._stubs['search_migratable_resources'] @property - def batch_migrate_resources( - self, - ) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - Awaitable[operations.Operation], - ]: + def batch_migrate_resources(self) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -328,13 +291,15 @@ def batch_migrate_resources( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "batch_migrate_resources" not in self._stubs: - self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", + if 'batch_migrate_resources' not in self._stubs: + self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["batch_migrate_resources"] + return self._stubs['batch_migrate_resources'] -__all__ = ("MigrationServiceGrpcAsyncIOTransport",) +__all__ = ( + 'MigrationServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py index b39295ebfe..3ee8fc6e9e 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import ModelServiceAsyncClient __all__ = ( - "ModelServiceClient", - "ModelServiceAsyncClient", + 'ModelServiceClient', + 'ModelServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index a901ead2b1..72cfd1e4e4 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -63,44 +63,26 @@ class ModelServiceAsyncClient: model_path = staticmethod(ModelServiceClient.model_path) parse_model_path = staticmethod(ModelServiceClient.parse_model_path) model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) - parse_model_evaluation_path = staticmethod( - ModelServiceClient.parse_model_evaluation_path - ) - model_evaluation_slice_path = staticmethod( - ModelServiceClient.model_evaluation_slice_path - ) - parse_model_evaluation_slice_path = staticmethod( - ModelServiceClient.parse_model_evaluation_slice_path - ) + parse_model_evaluation_path = staticmethod(ModelServiceClient.parse_model_evaluation_path) + model_evaluation_slice_path = staticmethod(ModelServiceClient.model_evaluation_slice_path) + parse_model_evaluation_slice_path = staticmethod(ModelServiceClient.parse_model_evaluation_slice_path) training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod( - ModelServiceClient.parse_training_pipeline_path - ) + parse_training_pipeline_path = staticmethod(ModelServiceClient.parse_training_pipeline_path) - common_billing_account_path = staticmethod( - ModelServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - ModelServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(ModelServiceClient.common_folder_path) parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod( - ModelServiceClient.parse_common_organization_path - ) + parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod( - ModelServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod( - ModelServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -143,18 +125,14 @@ def transport(self) -> ModelServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(ModelServiceClient).get_transport_class, type(ModelServiceClient) - ) + get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -193,18 +171,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def upload_model( - self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def upload_model(self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Uploads a Model artifact into AI Platform. Args: @@ -247,10 +225,8 @@ async def upload_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.UploadModelRequest(request) @@ -273,11 +249,18 @@ async def upload_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -290,15 +273,14 @@ async def upload_model( # Done; return the response. return response - async def get_model( - self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + async def get_model(self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -328,10 +310,8 @@ async def get_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.GetModelRequest(request) @@ -352,24 +332,30 @@ async def get_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_models( - self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: + async def list_models(self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: r"""Lists Models in a Location. Args: @@ -405,10 +391,8 @@ async def list_models( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ListModelsRequest(request) @@ -429,31 +413,40 @@ async def list_models( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def update_model( - self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + async def update_model(self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -491,10 +484,8 @@ async def update_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.UpdateModelRequest(request) @@ -517,26 +508,30 @@ async def update_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("model.name", request.model.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('model.name', request.model.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def delete_model( - self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_model(self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -584,10 +579,8 @@ async def delete_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.DeleteModelRequest(request) @@ -608,11 +601,18 @@ async def delete_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -625,16 +625,15 @@ async def delete_model( # Done; return the response. return response - async def export_model( - self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_model(self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -682,10 +681,8 @@ async def export_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ExportModelRequest(request) @@ -708,11 +705,18 @@ async def export_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -725,15 +729,14 @@ async def export_model( # Done; return the response. return response - async def get_model_evaluation( - self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + async def get_model_evaluation(self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -768,10 +771,8 @@ async def get_model_evaluation( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.GetModelEvaluationRequest(request) @@ -792,24 +793,30 @@ async def get_model_evaluation( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_model_evaluations( - self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsAsyncPager: + async def list_model_evaluations(self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: r"""Lists ModelEvaluations in a Model. Args: @@ -845,10 +852,8 @@ async def list_model_evaluations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ListModelEvaluationsRequest(request) @@ -869,30 +874,39 @@ async def list_model_evaluations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def get_model_evaluation_slice( - self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + async def get_model_evaluation_slice(self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -927,10 +941,8 @@ async def get_model_evaluation_slice( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.GetModelEvaluationSliceRequest(request) @@ -951,24 +963,30 @@ async def get_model_evaluation_slice( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_model_evaluation_slices( - self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesAsyncPager: + async def list_model_evaluation_slices(self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1004,10 +1022,8 @@ async def list_model_evaluation_slices( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = model_service.ListModelEvaluationSlicesRequest(request) @@ -1028,30 +1044,47 @@ async def list_model_evaluation_slices( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationSlicesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("ModelServiceAsyncClient",) +__all__ = ( + 'ModelServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 8b14e16e0b..29e081bc10 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,12 +61,13 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry["grpc"] = ModelServiceGrpcTransport - _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport + _transport_registry['grpc'] = ModelServiceGrpcTransport + _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +118,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,8 +153,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,162 +170,121 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path( - project: str, location: str, model: str, evaluation: str, - ) -> str: + def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( - project=project, location=location, model=model, evaluation=evaluation, - ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str, str]: + def parse_model_evaluation_path(path: str) -> Dict[str,str]: """Parse a model_evaluation path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path( - project: str, location: str, model: str, evaluation: str, slice: str, - ) -> str: + def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( - project=project, - location=location, - model=model, - evaluation=evaluation, - slice=slice, - ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path( - project: str, location: str, training_pipeline: str, - ) -> str: + def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str, str]: + def parse_training_pipeline_path(path: str) -> Dict[str,str]: """Parse a training_pipeline path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -367,9 +328,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -379,9 +338,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -393,9 +350,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -407,10 +362,8 @@ def __init__( if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -429,16 +382,15 @@ def __init__( client_info=client_info, ) - def upload_model( - self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def upload_model(self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -481,10 +433,8 @@ def upload_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -508,11 +458,18 @@ def upload_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -525,15 +482,14 @@ def upload_model( # Done; return the response. return response - def get_model( - self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model(self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -563,10 +519,8 @@ def get_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -588,24 +542,30 @@ def get_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_models( - self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models(self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -641,10 +601,8 @@ def list_models( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -666,31 +624,40 @@ def list_models( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def update_model( - self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model(self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -728,10 +695,8 @@ def update_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -755,26 +720,30 @@ def update_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("model.name", request.model.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('model.name', request.model.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_model( - self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model(self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -822,10 +791,8 @@ def delete_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -847,11 +814,18 @@ def delete_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -864,16 +838,15 @@ def delete_model( # Done; return the response. return response - def export_model( - self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_model(self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -921,10 +894,8 @@ def export_model( # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -948,11 +919,18 @@ def export_model( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -965,15 +943,14 @@ def export_model( # Done; return the response. return response - def get_model_evaluation( - self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation(self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -1008,10 +985,8 @@ def get_model_evaluation( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -1033,24 +1008,30 @@ def get_model_evaluation( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_model_evaluations( - self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations(self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1086,10 +1067,8 @@ def list_model_evaluations( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1111,30 +1090,39 @@ def list_model_evaluations( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice( - self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice(self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1169,10 +1157,8 @@ def get_model_evaluation_slice( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1189,31 +1175,35 @@ def get_model_evaluation_slice( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.get_model_evaluation_slice - ] + rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_model_evaluation_slices( - self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices(self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1249,10 +1239,8 @@ def list_model_evaluation_slices( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1269,37 +1257,52 @@ def list_model_evaluation_slices( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_model_evaluation_slices - ] + rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("ModelServiceClient",) +__all__ = ( + 'ModelServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index eb547a5f9f..c4d4d8696b 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation @@ -49,15 +40,12 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -91,7 +79,7 @@ def __iter__(self) -> Iterable[model.Model]: yield from page.models def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelsAsyncPager: @@ -111,15 +99,12 @@ class ListModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -157,7 +142,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationsPager: @@ -177,15 +162,12 @@ class ListModelEvaluationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., model_service.ListModelEvaluationsResponse], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -219,7 +201,7 @@ def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: yield from page.model_evaluations def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationsAsyncPager: @@ -239,15 +221,12 @@ class ListModelEvaluationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -285,7 +264,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesPager: @@ -305,15 +284,12 @@ class ListModelEvaluationSlicesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., model_service.ListModelEvaluationSlicesResponse], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -347,7 +323,7 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: yield from page.model_evaluation_slices def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesAsyncPager: @@ -367,17 +343,12 @@ class ListModelEvaluationSlicesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] - ], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationSlicesResponse]], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -399,9 +370,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -417,4 +386,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index 5d1cb51abc..833862a1d6 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry["grpc"] = ModelServiceGrpcTransport -_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = ModelServiceGrpcTransport +_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport __all__ = ( - "ModelServiceTransport", - "ModelServiceGrpcTransport", - "ModelServiceGrpcAsyncIOTransport", + 'ModelServiceTransport', + 'ModelServiceGrpcTransport', + 'ModelServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index 2f87fc98dd..40426aa4bd 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -37,29 +37,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -75,63 +75,75 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, default_timeout=5.0, client_info=client_info, + self.upload_model, + default_timeout=5.0, + client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( - self.get_model, default_timeout=5.0, client_info=client_info, + self.get_model, + default_timeout=5.0, + client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( - self.list_models, default_timeout=5.0, client_info=client_info, + self.list_models, + default_timeout=5.0, + client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( - self.update_model, default_timeout=5.0, client_info=client_info, + self.update_model, + default_timeout=5.0, + client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, default_timeout=5.0, client_info=client_info, + self.delete_model, + default_timeout=5.0, + client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( - self.export_model, default_timeout=5.0, client_info=client_info, + self.export_model, + default_timeout=5.0, + client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( - self.get_model_evaluation, default_timeout=5.0, client_info=client_info, + self.get_model_evaluation, + default_timeout=5.0, + client_info=client_info, ), self.list_model_evaluations: gapic_v1.method.wrap_method( self.list_model_evaluations, @@ -148,6 +160,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), + } @property @@ -156,109 +169,96 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def upload_model( - self, - ) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def upload_model(self) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_model( - self, - ) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[model.Model, typing.Awaitable[model.Model]], - ]: + def get_model(self) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[ + model.Model, + typing.Awaitable[model.Model] + ]]: raise NotImplementedError() @property - def list_models( - self, - ) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse], - ], - ]: + def list_models(self) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse] + ]]: raise NotImplementedError() @property - def update_model( - self, - ) -> typing.Callable[ - [model_service.UpdateModelRequest], - typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], - ]: + def update_model(self) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[ + gca_model.Model, + typing.Awaitable[gca_model.Model] + ]]: raise NotImplementedError() @property - def delete_model( - self, - ) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_model(self) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def export_model( - self, - ) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def export_model(self) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_model_evaluation( - self, - ) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation], - ], - ]: + def get_model_evaluation(self) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation] + ]]: raise NotImplementedError() @property - def list_model_evaluations( - self, - ) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse], - ], - ]: + def list_model_evaluations(self) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse] + ]]: raise NotImplementedError() @property - def get_model_evaluation_slice( - self, - ) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], - ], - ]: + def get_model_evaluation_slice(self) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] + ]]: raise NotImplementedError() @property - def list_model_evaluation_slices( - self, - ) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], - ], - ]: + def list_model_evaluation_slices(self) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] + ]]: raise NotImplementedError() -__all__ = ("ModelServiceTransport",) +__all__ = ( + 'ModelServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index b401612b1c..85db2fddd7 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -49,24 +49,21 @@ class ModelServiceGrpcTransport(ModelServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -112,7 +109,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -120,70 +120,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -191,32 +171,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -246,12 +214,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -263,15 +232,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def upload_model( - self, - ) -> Callable[[model_service.UploadModelRequest], operations.Operation]: + def upload_model(self) -> Callable[ + [model_service.UploadModelRequest], + operations.Operation]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -286,16 +257,18 @@ def upload_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "upload_model" not in self._stubs: - self._stubs["upload_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", + if 'upload_model' not in self._stubs: + self._stubs['upload_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["upload_model"] + return self._stubs['upload_model'] @property - def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + model.Model]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -310,18 +283,18 @@ def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model" not in self._stubs: - self._stubs["get_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs["get_model"] + return self._stubs['get_model'] @property - def list_models( - self, - ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -336,18 +309,18 @@ def list_models( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_models" not in self._stubs: - self._stubs["list_models"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs["list_models"] + return self._stubs['list_models'] @property - def update_model( - self, - ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: + def update_model(self) -> Callable[ + [model_service.UpdateModelRequest], + gca_model.Model]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -362,18 +335,18 @@ def update_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_model" not in self._stubs: - self._stubs["update_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", + if 'update_model' not in self._stubs: + self._stubs['update_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs["update_model"] + return self._stubs['update_model'] @property - def delete_model( - self, - ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: + def delete_model(self) -> Callable[ + [model_service.DeleteModelRequest], + operations.Operation]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -390,18 +363,18 @@ def delete_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_model" not in self._stubs: - self._stubs["delete_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", + if 'delete_model' not in self._stubs: + self._stubs['delete_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_model"] + return self._stubs['delete_model'] @property - def export_model( - self, - ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: + def export_model(self) -> Callable[ + [model_service.ExportModelRequest], + operations.Operation]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -419,20 +392,18 @@ def export_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_model" not in self._stubs: - self._stubs["export_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", + if 'export_model' not in self._stubs: + self._stubs['export_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_model"] + return self._stubs['export_model'] @property - def get_model_evaluation( - self, - ) -> Callable[ - [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation - ]: + def get_model_evaluation(self) -> Callable[ + [model_service.GetModelEvaluationRequest], + model_evaluation.ModelEvaluation]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -447,21 +418,18 @@ def get_model_evaluation( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation" not in self._stubs: - self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", + if 'get_model_evaluation' not in self._stubs: + self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs["get_model_evaluation"] + return self._stubs['get_model_evaluation'] @property - def list_model_evaluations( - self, - ) -> Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, - ]: + def list_model_evaluations(self) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -476,21 +444,18 @@ def list_model_evaluations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluations" not in self._stubs: - self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", + if 'list_model_evaluations' not in self._stubs: + self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs["list_model_evaluations"] + return self._stubs['list_model_evaluations'] @property - def get_model_evaluation_slice( - self, - ) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice, - ]: + def get_model_evaluation_slice(self) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -505,21 +470,18 @@ def get_model_evaluation_slice( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation_slice" not in self._stubs: - self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", + if 'get_model_evaluation_slice' not in self._stubs: + self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs["get_model_evaluation_slice"] + return self._stubs['get_model_evaluation_slice'] @property - def list_model_evaluation_slices( - self, - ) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, - ]: + def list_model_evaluation_slices(self) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -534,13 +496,15 @@ def list_model_evaluation_slices( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluation_slices" not in self._stubs: - self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", + if 'list_model_evaluation_slices' not in self._stubs: + self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs["list_model_evaluation_slices"] + return self._stubs['list_model_evaluation_slices'] -__all__ = ("ModelServiceGrpcTransport",) +__all__ = ( + 'ModelServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index d05bebeeec..bd8ae232f9 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import model @@ -56,18 +56,16 @@ class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -93,24 +91,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -145,10 +141,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -157,7 +153,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -165,70 +164,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -236,18 +215,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -276,9 +245,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def upload_model( - self, - ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: + def upload_model(self) -> Callable[ + [model_service.UploadModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -293,18 +262,18 @@ def upload_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "upload_model" not in self._stubs: - self._stubs["upload_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", + if 'upload_model' not in self._stubs: + self._stubs['upload_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["upload_model"] + return self._stubs['upload_model'] @property - def get_model( - self, - ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + Awaitable[model.Model]]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -319,20 +288,18 @@ def get_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model" not in self._stubs: - self._stubs["get_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs["get_model"] + return self._stubs['get_model'] @property - def list_models( - self, - ) -> Callable[ - [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] - ]: + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + Awaitable[model_service.ListModelsResponse]]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -347,18 +314,18 @@ def list_models( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_models" not in self._stubs: - self._stubs["list_models"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs["list_models"] + return self._stubs['list_models'] @property - def update_model( - self, - ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: + def update_model(self) -> Callable[ + [model_service.UpdateModelRequest], + Awaitable[gca_model.Model]]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -373,18 +340,18 @@ def update_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_model" not in self._stubs: - self._stubs["update_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", + if 'update_model' not in self._stubs: + self._stubs['update_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs["update_model"] + return self._stubs['update_model'] @property - def delete_model( - self, - ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: + def delete_model(self) -> Callable[ + [model_service.DeleteModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -401,18 +368,18 @@ def delete_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_model" not in self._stubs: - self._stubs["delete_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", + if 'delete_model' not in self._stubs: + self._stubs['delete_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_model"] + return self._stubs['delete_model'] @property - def export_model( - self, - ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: + def export_model(self) -> Callable[ + [model_service.ExportModelRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -430,21 +397,18 @@ def export_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_model" not in self._stubs: - self._stubs["export_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", + if 'export_model' not in self._stubs: + self._stubs['export_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_model"] + return self._stubs['export_model'] @property - def get_model_evaluation( - self, - ) -> Callable[ - [model_service.GetModelEvaluationRequest], - Awaitable[model_evaluation.ModelEvaluation], - ]: + def get_model_evaluation(self) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation]]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -459,21 +423,18 @@ def get_model_evaluation( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation" not in self._stubs: - self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", + if 'get_model_evaluation' not in self._stubs: + self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs["get_model_evaluation"] + return self._stubs['get_model_evaluation'] @property - def list_model_evaluations( - self, - ) -> Callable[ - [model_service.ListModelEvaluationsRequest], - Awaitable[model_service.ListModelEvaluationsResponse], - ]: + def list_model_evaluations(self) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse]]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -488,21 +449,18 @@ def list_model_evaluations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluations" not in self._stubs: - self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", + if 'list_model_evaluations' not in self._stubs: + self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs["list_model_evaluations"] + return self._stubs['list_model_evaluations'] @property - def get_model_evaluation_slice( - self, - ) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - Awaitable[model_evaluation_slice.ModelEvaluationSlice], - ]: + def get_model_evaluation_slice(self) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice]]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -517,21 +475,18 @@ def get_model_evaluation_slice( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation_slice" not in self._stubs: - self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", + if 'get_model_evaluation_slice' not in self._stubs: + self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs["get_model_evaluation_slice"] + return self._stubs['get_model_evaluation_slice'] @property - def list_model_evaluation_slices( - self, - ) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - Awaitable[model_service.ListModelEvaluationSlicesResponse], - ]: + def list_model_evaluation_slices(self) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse]]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -546,13 +501,15 @@ def list_model_evaluation_slices( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_model_evaluation_slices" not in self._stubs: - self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", + if 'list_model_evaluation_slices' not in self._stubs: + self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs["list_model_evaluation_slices"] + return self._stubs['list_model_evaluation_slices'] -__all__ = ("ModelServiceGrpcAsyncIOTransport",) +__all__ = ( + 'ModelServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py index 7f02b47358..f7f4d9b9ac 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PipelineServiceAsyncClient __all__ = ( - "PipelineServiceClient", - "PipelineServiceAsyncClient", + 'PipelineServiceClient', + 'PipelineServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 063153700c..6235697be1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -37,9 +37,7 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -63,38 +61,22 @@ class PipelineServiceAsyncClient: model_path = staticmethod(PipelineServiceClient.model_path) parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod( - PipelineServiceClient.parse_training_pipeline_path - ) + parse_training_pipeline_path = staticmethod(PipelineServiceClient.parse_training_pipeline_path) - common_billing_account_path = staticmethod( - PipelineServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - PipelineServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(PipelineServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PipelineServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - PipelineServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(PipelineServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - PipelineServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - PipelineServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(PipelineServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PipelineServiceClient.parse_common_organization_path) common_project_path = staticmethod(PipelineServiceClient.common_project_path) - parse_common_project_path = staticmethod( - PipelineServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(PipelineServiceClient.parse_common_project_path) common_location_path = staticmethod(PipelineServiceClient.common_location_path) - parse_common_location_path = staticmethod( - PipelineServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(PipelineServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -137,18 +119,14 @@ def transport(self) -> PipelineServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) - ) + get_transport_class = functools.partial(type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -187,18 +165,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_training_pipeline( - self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + async def create_training_pipeline(self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -243,10 +221,8 @@ async def create_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.CreateTrainingPipelineRequest(request) @@ -269,24 +245,30 @@ async def create_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_training_pipeline( - self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + async def get_training_pipeline(self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -323,10 +305,8 @@ async def get_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.GetTrainingPipelineRequest(request) @@ -347,24 +327,30 @@ async def get_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_training_pipelines( - self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesAsyncPager: + async def list_training_pipelines(self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: r"""Lists TrainingPipelines in a Location. Args: @@ -400,10 +386,8 @@ async def list_training_pipelines( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.ListTrainingPipelinesRequest(request) @@ -424,30 +408,39 @@ async def list_training_pipelines( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrainingPipelinesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_training_pipeline( - self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_training_pipeline(self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a TrainingPipeline. Args: @@ -493,10 +486,8 @@ async def delete_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.DeleteTrainingPipelineRequest(request) @@ -517,11 +508,18 @@ async def delete_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -534,15 +532,14 @@ async def delete_training_pipeline( # Done; return the response. return response - async def cancel_training_pipeline( - self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_training_pipeline(self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -581,10 +578,8 @@ async def cancel_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = pipeline_service.CancelTrainingPipelineRequest(request) @@ -605,23 +600,35 @@ async def cancel_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PipelineServiceAsyncClient",) +__all__ = ( + 'PipelineServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 4efc2064b5..07f1ac0444 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -41,9 +41,7 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -61,14 +59,13 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry['grpc'] = PipelineServiceGrpcTransport + _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry["grpc"] = PipelineServiceGrpcTransport - _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport - - def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -119,7 +116,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -154,8 +151,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PipelineServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -170,122 +168,99 @@ def transport(self) -> PipelineServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: + def parse_model_path(path: str) -> Dict[str,str]: """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path( - project: str, location: str, training_pipeline: str, - ) -> str: + def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str, str]: + def parse_training_pipeline_path(path: str) -> Dict[str,str]: """Parse a training_pipeline path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -329,9 +304,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -341,9 +314,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -355,9 +326,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -369,10 +338,8 @@ def __init__( if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -391,16 +358,15 @@ def __init__( client_info=client_info, ) - def create_training_pipeline( - self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline(self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -445,10 +411,8 @@ def create_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -472,24 +436,30 @@ def create_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_training_pipeline( - self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline(self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -526,10 +496,8 @@ def get_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -551,24 +519,30 @@ def get_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_training_pipelines( - self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines(self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -604,10 +578,8 @@ def list_training_pipelines( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -629,30 +601,39 @@ def list_training_pipelines( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline( - self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_training_pipeline(self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -698,10 +679,8 @@ def delete_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -723,11 +702,18 @@ def delete_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -740,15 +726,14 @@ def delete_training_pipeline( # Done; return the response. return response - def cancel_training_pipeline( - self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline(self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -787,10 +772,8 @@ def cancel_training_pipeline( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -812,23 +795,35 @@ def cancel_training_pipeline( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PipelineServiceClient",) +__all__ = ( + 'PipelineServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index db2b4dd3a1..6de70ee1f1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -47,15 +38,12 @@ class ListTrainingPipelinesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: yield from page.training_pipelines def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListTrainingPipelinesAsyncPager: @@ -109,17 +97,12 @@ class ListTrainingPipelinesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] - ], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[pipeline_service.ListTrainingPipelinesResponse]], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -141,9 +124,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + async def pages(self) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -159,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index 9d4610087a..f289718f83 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] -_transport_registry["grpc"] = PipelineServiceGrpcTransport -_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = PipelineServiceGrpcTransport +_transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport __all__ = ( - "PipelineServiceTransport", - "PipelineServiceGrpcTransport", - "PipelineServiceGrpcAsyncIOTransport", + 'PipelineServiceTransport', + 'PipelineServiceGrpcTransport', + 'PipelineServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 41123b8615..30070650b2 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -21,16 +21,14 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -38,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -76,40 +74,38 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -138,6 +134,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), + } @property @@ -146,58 +143,51 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - typing.Union[ - gca_training_pipeline.TrainingPipeline, - typing.Awaitable[gca_training_pipeline.TrainingPipeline], - ], - ]: + def create_training_pipeline(self) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline] + ]]: raise NotImplementedError() @property - def get_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.GetTrainingPipelineRequest], - typing.Union[ - training_pipeline.TrainingPipeline, - typing.Awaitable[training_pipeline.TrainingPipeline], - ], - ]: + def get_training_pipeline(self) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline] + ]]: raise NotImplementedError() @property - def list_training_pipelines( - self, - ) -> typing.Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - typing.Union[ - pipeline_service.ListTrainingPipelinesResponse, - typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], - ], - ]: + def list_training_pipelines(self) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ]]: raise NotImplementedError() @property - def delete_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_training_pipeline(self) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def cancel_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def cancel_training_pipeline(self) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() -__all__ = ("PipelineServiceTransport",) +__all__ = ( + 'PipelineServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 83383d9e87..9c024143ef 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -18,20 +18,18 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -50,24 +48,21 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -113,7 +108,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -121,70 +119,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -192,32 +170,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -247,12 +213,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -264,18 +231,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline, - ]: + def create_training_pipeline(self) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + gca_training_pipeline.TrainingPipeline]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -291,21 +257,18 @@ def create_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_training_pipeline" not in self._stubs: - self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", + if 'create_training_pipeline' not in self._stubs: + self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["create_training_pipeline"] + return self._stubs['create_training_pipeline'] @property - def get_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline, - ]: + def get_training_pipeline(self) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -320,21 +283,18 @@ def get_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_training_pipeline" not in self._stubs: - self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", + if 'get_training_pipeline' not in self._stubs: + self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["get_training_pipeline"] + return self._stubs['get_training_pipeline'] @property - def list_training_pipelines( - self, - ) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, - ]: + def list_training_pipelines(self) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -349,20 +309,18 @@ def list_training_pipelines( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_training_pipelines" not in self._stubs: - self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", + if 'list_training_pipelines' not in self._stubs: + self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs["list_training_pipelines"] + return self._stubs['list_training_pipelines'] @property - def delete_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation - ]: + def delete_training_pipeline(self) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + operations.Operation]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -377,18 +335,18 @@ def delete_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_training_pipeline" not in self._stubs: - self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", + if 'delete_training_pipeline' not in self._stubs: + self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_training_pipeline"] + return self._stubs['delete_training_pipeline'] @property - def cancel_training_pipeline( - self, - ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: + def cancel_training_pipeline(self) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + empty.Empty]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -415,13 +373,15 @@ def cancel_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_training_pipeline" not in self._stubs: - self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", + if 'cancel_training_pipeline' not in self._stubs: + self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_training_pipeline"] + return self._stubs['cancel_training_pipeline'] -__all__ = ("PipelineServiceGrpcTransport",) +__all__ = ( + 'PipelineServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index 76f21faf50..53bd371d65 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,21 +18,19 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -57,18 +55,16 @@ class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -94,24 +90,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -146,10 +140,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -158,7 +152,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -166,70 +163,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -237,18 +214,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -277,12 +244,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - Awaitable[gca_training_pipeline.TrainingPipeline], - ]: + def create_training_pipeline(self) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline]]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -298,21 +262,18 @@ def create_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_training_pipeline" not in self._stubs: - self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", + if 'create_training_pipeline' not in self._stubs: + self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["create_training_pipeline"] + return self._stubs['create_training_pipeline'] @property - def get_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - Awaitable[training_pipeline.TrainingPipeline], - ]: + def get_training_pipeline(self) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline]]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -327,21 +288,18 @@ def get_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_training_pipeline" not in self._stubs: - self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", + if 'get_training_pipeline' not in self._stubs: + self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["get_training_pipeline"] + return self._stubs['get_training_pipeline'] @property - def list_training_pipelines( - self, - ) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - Awaitable[pipeline_service.ListTrainingPipelinesResponse], - ]: + def list_training_pipelines(self) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse]]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -356,21 +314,18 @@ def list_training_pipelines( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_training_pipelines" not in self._stubs: - self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", + if 'list_training_pipelines' not in self._stubs: + self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs["list_training_pipelines"] + return self._stubs['list_training_pipelines'] @property - def delete_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - Awaitable[operations.Operation], - ]: + def delete_training_pipeline(self) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -385,20 +340,18 @@ def delete_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_training_pipeline" not in self._stubs: - self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", + if 'delete_training_pipeline' not in self._stubs: + self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_training_pipeline"] + return self._stubs['delete_training_pipeline'] @property - def cancel_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] - ]: + def cancel_training_pipeline(self) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -425,13 +378,15 @@ def cancel_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_training_pipeline" not in self._stubs: - self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", + if 'cancel_training_pipeline' not in self._stubs: + self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_training_pipeline"] + return self._stubs['cancel_training_pipeline'] -__all__ = ("PipelineServiceGrpcAsyncIOTransport",) +__all__ = ( + 'PipelineServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py index 0c847693e0..d4047c335d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PredictionServiceAsyncClient __all__ = ( - "PredictionServiceClient", - "PredictionServiceAsyncClient", + 'PredictionServiceClient', + 'PredictionServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index 4d69a6635f..64b514608c 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -48,34 +48,20 @@ class PredictionServiceAsyncClient: endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) - common_billing_account_path = staticmethod( - PredictionServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - PredictionServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(PredictionServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PredictionServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - PredictionServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(PredictionServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - PredictionServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - PredictionServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(PredictionServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PredictionServiceClient.parse_common_organization_path) common_project_path = staticmethod(PredictionServiceClient.common_project_path) - parse_common_project_path = staticmethod( - PredictionServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(PredictionServiceClient.parse_common_project_path) common_location_path = staticmethod(PredictionServiceClient.common_location_path) - parse_common_location_path = staticmethod( - PredictionServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(PredictionServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -118,18 +104,14 @@ def transport(self) -> PredictionServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) - ) + get_transport_class = functools.partial(type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -168,19 +150,19 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def predict( - self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + async def predict(self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -240,10 +222,8 @@ async def predict( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = prediction_service.PredictRequest(request) @@ -269,27 +249,33 @@ async def predict( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def explain( - self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + async def explain(self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. If @@ -368,10 +354,8 @@ async def explain( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = prediction_service.ExplainRequest(request) @@ -399,24 +383,38 @@ async def explain( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PredictionServiceAsyncClient",) +__all__ = ( + 'PredictionServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 042307eca1..097cf3d0fe 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -48,16 +48,13 @@ class PredictionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry['grpc'] = PredictionServiceGrpcTransport + _transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[PredictionServiceTransport]] - _transport_registry["grpc"] = PredictionServiceGrpcTransport - _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport - - def get_transport_class( - cls, label: str = None, - ) -> Type[PredictionServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[PredictionServiceTransport]: """Return an appropriate transport class. Args: @@ -108,7 +105,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -143,8 +140,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PredictionServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -159,88 +157,77 @@ def transport(self) -> PredictionServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str, location: str, endpoint: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: + def parse_endpoint_path(path: str) -> Dict[str,str]: """Parse a endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -284,9 +271,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -296,9 +281,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -310,9 +293,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -324,10 +305,8 @@ def __init__( if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -346,17 +325,16 @@ def __init__( client_info=client_info, ) - def predict( - self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + def predict(self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -416,10 +394,8 @@ def predict( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a prediction_service.PredictRequest. @@ -445,27 +421,33 @@ def predict( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def explain( - self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + def explain(self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. If @@ -544,10 +526,8 @@ def explain( # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a prediction_service.ExplainRequest. @@ -575,24 +555,38 @@ def explain( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PredictionServiceClient",) +__all__ = ( + 'PredictionServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py index 9ec1369a05..15b5acb198 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] -_transport_registry["grpc"] = PredictionServiceGrpcTransport -_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = PredictionServiceGrpcTransport +_transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport __all__ = ( - "PredictionServiceTransport", - "PredictionServiceGrpcTransport", - "PredictionServiceGrpcAsyncIOTransport", + 'PredictionServiceTransport', + 'PredictionServiceGrpcTransport', + 'PredictionServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 0c82f7d83c..d391018e2c 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -69,74 +69,73 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( - self.predict, default_timeout=5.0, client_info=client_info, + self.predict, + default_timeout=5.0, + client_info=client_info, ), self.explain: gapic_v1.method.wrap_method( - self.explain, default_timeout=5.0, client_info=client_info, + self.explain, + default_timeout=5.0, + client_info=client_info, ), + } @property - def predict( - self, - ) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse], - ], - ]: + def predict(self) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse] + ]]: raise NotImplementedError() @property - def explain( - self, - ) -> typing.Callable[ - [prediction_service.ExplainRequest], - typing.Union[ - prediction_service.ExplainResponse, - typing.Awaitable[prediction_service.ExplainResponse], - ], - ]: + def explain(self) -> typing.Callable[ + [prediction_service.ExplainRequest], + typing.Union[ + prediction_service.ExplainResponse, + typing.Awaitable[prediction_service.ExplainResponse] + ]]: raise NotImplementedError() -__all__ = ("PredictionServiceTransport",) +__all__ = ( + 'PredictionServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index f3b9be0c3d..ae5dfad093 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -18,10 +18,10 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -43,24 +43,21 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -106,7 +103,9 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -114,70 +113,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -185,31 +164,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -239,20 +207,19 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property - def predict( - self, - ) -> Callable[ - [prediction_service.PredictRequest], prediction_service.PredictResponse - ]: + def predict(self) -> Callable[ + [prediction_service.PredictRequest], + prediction_service.PredictResponse]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -267,20 +234,18 @@ def predict( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "predict" not in self._stubs: - self._stubs["predict"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", + if 'predict' not in self._stubs: + self._stubs['predict'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs["predict"] + return self._stubs['predict'] @property - def explain( - self, - ) -> Callable[ - [prediction_service.ExplainRequest], prediction_service.ExplainResponse - ]: + def explain(self) -> Callable[ + [prediction_service.ExplainRequest], + prediction_service.ExplainResponse]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. @@ -306,13 +271,15 @@ def explain( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "explain" not in self._stubs: - self._stubs["explain"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", + if 'explain' not in self._stubs: + self._stubs['explain'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs["explain"] + return self._stubs['explain'] -__all__ = ("PredictionServiceGrpcTransport",) +__all__ = ( + 'PredictionServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index e1493acc9c..69fbb7edeb 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -18,13 +18,13 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -50,18 +50,16 @@ class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -87,24 +85,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -139,10 +135,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -151,7 +147,9 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -159,70 +157,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -230,17 +208,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -253,12 +222,9 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def predict( - self, - ) -> Callable[ - [prediction_service.PredictRequest], - Awaitable[prediction_service.PredictResponse], - ]: + def predict(self) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse]]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -273,21 +239,18 @@ def predict( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "predict" not in self._stubs: - self._stubs["predict"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", + if 'predict' not in self._stubs: + self._stubs['predict'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs["predict"] + return self._stubs['predict'] @property - def explain( - self, - ) -> Callable[ - [prediction_service.ExplainRequest], - Awaitable[prediction_service.ExplainResponse], - ]: + def explain(self) -> Callable[ + [prediction_service.ExplainRequest], + Awaitable[prediction_service.ExplainResponse]]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. @@ -313,13 +276,15 @@ def explain( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "explain" not in self._stubs: - self._stubs["explain"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", + if 'explain' not in self._stubs: + self._stubs['explain'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs["explain"] + return self._stubs['explain'] -__all__ = ("PredictionServiceGrpcAsyncIOTransport",) +__all__ = ( + 'PredictionServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py index 49e9cdf0a0..e4247d7758 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import SpecialistPoolServiceAsyncClient __all__ = ( - "SpecialistPoolServiceClient", - "SpecialistPoolServiceAsyncClient", + 'SpecialistPoolServiceClient', + 'SpecialistPoolServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index 6907135b53..a6de6886e7 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,43 +57,23 @@ class SpecialistPoolServiceAsyncClient: DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT - specialist_pool_path = staticmethod( - SpecialistPoolServiceClient.specialist_pool_path - ) - parse_specialist_pool_path = staticmethod( - SpecialistPoolServiceClient.parse_specialist_pool_path - ) + specialist_pool_path = staticmethod(SpecialistPoolServiceClient.specialist_pool_path) + parse_specialist_pool_path = staticmethod(SpecialistPoolServiceClient.parse_specialist_pool_path) - common_billing_account_path = staticmethod( - SpecialistPoolServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - SpecialistPoolServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(SpecialistPoolServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(SpecialistPoolServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - SpecialistPoolServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(SpecialistPoolServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - SpecialistPoolServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - SpecialistPoolServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(SpecialistPoolServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(SpecialistPoolServiceClient.parse_common_organization_path) common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) - parse_common_project_path = staticmethod( - SpecialistPoolServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(SpecialistPoolServiceClient.parse_common_project_path) - common_location_path = staticmethod( - SpecialistPoolServiceClient.common_location_path - ) - parse_common_location_path = staticmethod( - SpecialistPoolServiceClient.parse_common_location_path - ) + common_location_path = staticmethod(SpecialistPoolServiceClient.common_location_path) + parse_common_location_path = staticmethod(SpecialistPoolServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -136,19 +116,14 @@ def transport(self) -> SpecialistPoolServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(SpecialistPoolServiceClient).get_transport_class, - type(SpecialistPoolServiceClient), - ) + get_transport_class = functools.partial(type(SpecialistPoolServiceClient).get_transport_class, type(SpecialistPoolServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -187,18 +162,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_specialist_pool( - self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_specialist_pool(self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a SpecialistPool. Args: @@ -246,10 +221,8 @@ async def create_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.CreateSpecialistPoolRequest(request) @@ -272,11 +245,18 @@ async def create_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -289,15 +269,14 @@ async def create_specialist_pool( # Done; return the response. return response - async def get_specialist_pool( - self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + async def get_specialist_pool(self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -339,10 +318,8 @@ async def get_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.GetSpecialistPoolRequest(request) @@ -363,24 +340,30 @@ async def get_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_specialist_pools( - self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsAsyncPager: + async def list_specialist_pools(self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: r"""Lists SpecialistPools in a Location. Args: @@ -416,10 +399,8 @@ async def list_specialist_pools( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.ListSpecialistPoolsRequest(request) @@ -440,30 +421,39 @@ async def list_specialist_pools( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListSpecialistPoolsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_specialist_pool( - self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_specialist_pool(self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -510,10 +500,8 @@ async def delete_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.DeleteSpecialistPoolRequest(request) @@ -534,11 +522,18 @@ async def delete_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -551,16 +546,15 @@ async def delete_specialist_pool( # Done; return the response. return response - async def update_specialist_pool( - self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_specialist_pool(self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates a SpecialistPool. Args: @@ -607,10 +601,8 @@ async def update_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = specialist_pool_service.UpdateSpecialistPoolRequest(request) @@ -633,13 +625,18 @@ async def update_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("specialist_pool.name", request.specialist_pool.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('specialist_pool.name', request.specialist_pool.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -653,14 +650,21 @@ async def update_specialist_pool( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("SpecialistPoolServiceAsyncClient",) +__all__ = ( + 'SpecialistPoolServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index cde21b3720..813d6413ff 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,16 +54,13 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport + _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport - _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport - - def get_transport_class( - cls, label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -120,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -155,8 +152,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: SpecialistPoolServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -171,88 +169,77 @@ def transport(self) -> SpecialistPoolServiceTransport: return self._transport @staticmethod - def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: + def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" - return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( - project=project, location=location, specialist_pool=specialist_pool, - ) + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) @staticmethod - def parse_specialist_pool_path(path: str) -> Dict[str, str]: + def parse_specialist_pool_path(path: str) -> Dict[str,str]: """Parse a specialist_pool path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -296,9 +283,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -308,9 +293,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -322,9 +305,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -336,10 +317,8 @@ def __init__( if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -358,16 +337,15 @@ def __init__( client_info=client_info, ) - def create_specialist_pool( - self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_specialist_pool(self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -415,10 +393,8 @@ def create_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -442,11 +418,18 @@ def create_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -459,15 +442,14 @@ def create_specialist_pool( # Done; return the response. return response - def get_specialist_pool( - self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool(self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -509,10 +491,8 @@ def get_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -534,24 +514,30 @@ def get_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_specialist_pools( - self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools(self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -587,10 +573,8 @@ def list_specialist_pools( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -612,30 +596,39 @@ def list_specialist_pools( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool( - self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_specialist_pool(self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -682,10 +675,8 @@ def delete_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -707,11 +698,18 @@ def delete_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -724,16 +722,15 @@ def delete_specialist_pool( # Done; return the response. return response - def update_specialist_pool( - self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_specialist_pool(self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -780,10 +777,8 @@ def update_specialist_pool( # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -807,13 +802,18 @@ def update_specialist_pool( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("specialist_pool.name", request.specialist_pool.name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('specialist_pool.name', request.specialist_pool.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -827,14 +827,21 @@ def update_specialist_pool( return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("SpecialistPoolServiceClient",) +__all__ = ( + 'SpecialistPoolServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 976bcf55b8..6b5d115c82 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service @@ -47,15 +38,12 @@ class ListSpecialistPoolsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: yield from page.specialist_pools def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListSpecialistPoolsAsyncPager: @@ -109,17 +97,12 @@ class ListSpecialistPoolsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -141,9 +124,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages( - self, - ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + async def pages(self) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -159,4 +140,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index 1bb2fbf22a..80de7b209f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -24,14 +24,12 @@ # Compile a registry of transports. -_transport_registry = ( - OrderedDict() -) # type: Dict[str, Type[SpecialistPoolServiceTransport]] -_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport -_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport +_transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] +_transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport +_transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( - "SpecialistPoolServiceTransport", - "SpecialistPoolServiceGrpcTransport", - "SpecialistPoolServiceGrpcAsyncIOTransport", + 'SpecialistPoolServiceTransport', + 'SpecialistPoolServiceGrpcTransport', + 'SpecialistPoolServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index f1af058030..43c7e87f16 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -72,40 +72,38 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -115,7 +113,9 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, default_timeout=5.0, client_info=client_info, + self.get_specialist_pool, + default_timeout=5.0, + client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, @@ -132,6 +132,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), + } @property @@ -140,55 +141,51 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def create_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def get_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool], - ], - ]: + def get_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool] + ]]: raise NotImplementedError() @property - def list_specialist_pools( - self, - ) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], - ], - ]: + def list_specialist_pools(self) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ]]: raise NotImplementedError() @property - def delete_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def delete_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def update_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def update_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() -__all__ = ("SpecialistPoolServiceTransport",) +__all__ = ( + 'SpecialistPoolServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index dbc31f0c7e..256765e7eb 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,24 +51,21 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -114,7 +111,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -122,70 +122,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -193,32 +173,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -248,12 +216,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -265,17 +234,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation - ]: + def create_specialist_pool(self) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -290,21 +259,18 @@ def create_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_specialist_pool" not in self._stubs: - self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", + if 'create_specialist_pool' not in self._stubs: + self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_specialist_pool"] + return self._stubs['create_specialist_pool'] @property - def get_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool, - ]: + def get_specialist_pool(self) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -319,21 +285,18 @@ def get_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_specialist_pool" not in self._stubs: - self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", + if 'get_specialist_pool' not in self._stubs: + self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs["get_specialist_pool"] + return self._stubs['get_specialist_pool'] @property - def list_specialist_pools( - self, - ) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, - ]: + def list_specialist_pools(self) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -348,20 +311,18 @@ def list_specialist_pools( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_specialist_pools" not in self._stubs: - self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", + if 'list_specialist_pools' not in self._stubs: + self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs["list_specialist_pools"] + return self._stubs['list_specialist_pools'] @property - def delete_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation - ]: + def delete_specialist_pool(self) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -377,20 +338,18 @@ def delete_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_specialist_pool" not in self._stubs: - self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", + if 'delete_specialist_pool' not in self._stubs: + self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_specialist_pool"] + return self._stubs['delete_specialist_pool'] @property - def update_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation - ]: + def update_specialist_pool(self) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -405,13 +364,15 @@ def update_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_specialist_pool" not in self._stubs: - self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", + if 'update_specialist_pool' not in self._stubs: + self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["update_specialist_pool"] + return self._stubs['update_specialist_pool'] -__all__ = ("SpecialistPoolServiceGrpcTransport",) +__all__ = ( + 'SpecialistPoolServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index a71d380b5b..8bf8ea2c2e 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import specialist_pool @@ -58,18 +58,16 @@ class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -95,24 +93,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -147,10 +143,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -159,7 +155,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -167,70 +166,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -238,18 +217,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -278,12 +247,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: + def create_specialist_pool(self) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -298,21 +264,18 @@ def create_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_specialist_pool" not in self._stubs: - self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", + if 'create_specialist_pool' not in self._stubs: + self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_specialist_pool"] + return self._stubs['create_specialist_pool'] @property - def get_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - Awaitable[specialist_pool.SpecialistPool], - ]: + def get_specialist_pool(self) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool]]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -327,21 +290,18 @@ def get_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_specialist_pool" not in self._stubs: - self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", + if 'get_specialist_pool' not in self._stubs: + self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs["get_specialist_pool"] + return self._stubs['get_specialist_pool'] @property - def list_specialist_pools( - self, - ) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], - ]: + def list_specialist_pools(self) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -356,21 +316,18 @@ def list_specialist_pools( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_specialist_pools" not in self._stubs: - self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", + if 'list_specialist_pools' not in self._stubs: + self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs["list_specialist_pools"] + return self._stubs['list_specialist_pools'] @property - def delete_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: + def delete_specialist_pool(self) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -386,21 +343,18 @@ def delete_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_specialist_pool" not in self._stubs: - self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", + if 'delete_specialist_pool' not in self._stubs: + self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_specialist_pool"] + return self._stubs['delete_specialist_pool'] @property - def update_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: + def update_specialist_pool(self) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -415,13 +369,15 @@ def update_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_specialist_pool" not in self._stubs: - self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", + if 'update_specialist_pool' not in self._stubs: + self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["update_specialist_pool"] + return self._stubs['update_specialist_pool'] -__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) +__all__ = ( + 'SpecialistPoolServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py index 5c312868f1..4c173a843c 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import VizierServiceAsyncClient __all__ = ( - "VizierServiceClient", - "VizierServiceAsyncClient", + 'VizierServiceClient', + 'VizierServiceAsyncClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py index 4bd90a79cd..4844bd0528 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,34 +60,20 @@ class VizierServiceAsyncClient: trial_path = staticmethod(VizierServiceClient.trial_path) parse_trial_path = staticmethod(VizierServiceClient.parse_trial_path) - common_billing_account_path = staticmethod( - VizierServiceClient.common_billing_account_path - ) - parse_common_billing_account_path = staticmethod( - VizierServiceClient.parse_common_billing_account_path - ) + common_billing_account_path = staticmethod(VizierServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(VizierServiceClient.parse_common_billing_account_path) common_folder_path = staticmethod(VizierServiceClient.common_folder_path) - parse_common_folder_path = staticmethod( - VizierServiceClient.parse_common_folder_path - ) + parse_common_folder_path = staticmethod(VizierServiceClient.parse_common_folder_path) - common_organization_path = staticmethod( - VizierServiceClient.common_organization_path - ) - parse_common_organization_path = staticmethod( - VizierServiceClient.parse_common_organization_path - ) + common_organization_path = staticmethod(VizierServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(VizierServiceClient.parse_common_organization_path) common_project_path = staticmethod(VizierServiceClient.common_project_path) - parse_common_project_path = staticmethod( - VizierServiceClient.parse_common_project_path - ) + parse_common_project_path = staticmethod(VizierServiceClient.parse_common_project_path) common_location_path = staticmethod(VizierServiceClient.common_location_path) - parse_common_location_path = staticmethod( - VizierServiceClient.parse_common_location_path - ) + parse_common_location_path = staticmethod(VizierServiceClient.parse_common_location_path) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -130,18 +116,14 @@ def transport(self) -> VizierServiceTransport: """ return self._client.transport - get_transport_class = functools.partial( - type(VizierServiceClient).get_transport_class, type(VizierServiceClient) - ) + get_transport_class = functools.partial(type(VizierServiceClient).get_transport_class, type(VizierServiceClient)) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, VizierServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, VizierServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the vizier service client. Args: @@ -180,18 +162,18 @@ def __init__( transport=transport, client_options=client_options, client_info=client_info, + ) - async def create_study( - self, - request: vizier_service.CreateStudyRequest = None, - *, - parent: str = None, - study: gca_study.Study = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_study.Study: + async def create_study(self, + request: vizier_service.CreateStudyRequest = None, + *, + parent: str = None, + study: gca_study.Study = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_study.Study: r"""Creates a Study. A resource name will be generated after creation of the Study. @@ -230,10 +212,8 @@ async def create_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.CreateStudyRequest(request) @@ -256,24 +236,30 @@ async def create_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_study( - self, - request: vizier_service.GetStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + async def get_study(self, + request: vizier_service.GetStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Gets a Study by name. Args: @@ -303,10 +289,8 @@ async def get_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.GetStudyRequest(request) @@ -327,24 +311,30 @@ async def get_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_studies( - self, - request: vizier_service.ListStudiesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListStudiesAsyncPager: + async def list_studies(self, + request: vizier_service.ListStudiesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListStudiesAsyncPager: r"""Lists all the studies in a region for an associated project. @@ -381,10 +371,8 @@ async def list_studies( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.ListStudiesRequest(request) @@ -405,30 +393,39 @@ async def list_studies( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListStudiesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def delete_study( - self, - request: vizier_service.DeleteStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def delete_study(self, + request: vizier_service.DeleteStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Study. Args: @@ -455,10 +452,8 @@ async def delete_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.DeleteStudyRequest(request) @@ -479,23 +474,27 @@ async def delete_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - async def lookup_study( - self, - request: vizier_service.LookupStudyRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + async def lookup_study(self, + request: vizier_service.LookupStudyRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Looks a study up using the user-defined display_name field instead of the fully qualified resource name. @@ -527,10 +526,8 @@ async def lookup_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.LookupStudyRequest(request) @@ -551,23 +548,29 @@ async def lookup_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def suggest_trials( - self, - request: vizier_service.SuggestTrialsRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def suggest_trials(self, + request: vizier_service.SuggestTrialsRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Adds one or more Trials to a Study, with parameter values suggested by AI Platform Vizier. Returns a long-running operation associated with the generation of Trial suggestions. @@ -610,11 +613,18 @@ async def suggest_trials( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -627,16 +637,15 @@ async def suggest_trials( # Done; return the response. return response - async def create_trial( - self, - request: vizier_service.CreateTrialRequest = None, - *, - parent: str = None, - trial: study.Trial = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def create_trial(self, + request: vizier_service.CreateTrialRequest = None, + *, + parent: str = None, + trial: study.Trial = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a user provided Trial to a Study. Args: @@ -677,10 +686,8 @@ async def create_trial( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.CreateTrialRequest(request) @@ -703,24 +710,30 @@ async def create_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def get_trial( - self, - request: vizier_service.GetTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def get_trial(self, + request: vizier_service.GetTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Gets a Trial. Args: @@ -755,10 +768,8 @@ async def get_trial( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.GetTrialRequest(request) @@ -779,24 +790,30 @@ async def get_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_trials( - self, - request: vizier_service.ListTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrialsAsyncPager: + async def list_trials(self, + request: vizier_service.ListTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrialsAsyncPager: r"""Lists the Trials associated with a Study. Args: @@ -832,10 +849,8 @@ async def list_trials( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.ListTrialsRequest(request) @@ -856,29 +871,38 @@ async def list_trials( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrialsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - async def add_trial_measurement( - self, - request: vizier_service.AddTrialMeasurementRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def add_trial_measurement(self, + request: vizier_service.AddTrialMeasurementRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a measurement of the objective metrics to a Trial. This measurement is assumed to have been taken before the Trial is complete. @@ -918,25 +942,29 @@ async def add_trial_measurement( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("trial_name", request.trial_name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('trial_name', request.trial_name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def complete_trial( - self, - request: vizier_service.CompleteTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def complete_trial(self, + request: vizier_service.CompleteTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Marks a Trial as complete. Args: @@ -974,24 +1002,30 @@ async def complete_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def delete_trial( - self, - request: vizier_service.DeleteTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def delete_trial(self, + request: vizier_service.DeleteTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Trial. Args: @@ -1017,10 +1051,8 @@ async def delete_trial( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.DeleteTrialRequest(request) @@ -1041,22 +1073,26 @@ async def delete_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - async def check_trial_early_stopping_state( - self, - request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def check_trial_early_stopping_state(self, + request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Checks whether a Trial should stop or not. Returns a long-running operation. When the operation is successful, it will contain a @@ -1098,13 +1134,18 @@ async def check_trial_early_stopping_state( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("trial_name", request.trial_name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('trial_name', request.trial_name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1117,14 +1158,13 @@ async def check_trial_early_stopping_state( # Done; return the response. return response - async def stop_trial( - self, - request: vizier_service.StopTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def stop_trial(self, + request: vizier_service.StopTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Stops a Trial. Args: @@ -1162,24 +1202,30 @@ async def stop_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - async def list_optimal_trials( - self, - request: vizier_service.ListOptimalTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> vizier_service.ListOptimalTrialsResponse: + async def list_optimal_trials(self, + request: vizier_service.ListOptimalTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> vizier_service.ListOptimalTrialsResponse: r"""Lists the pareto-optimal Trials for multi-objective Study or the optimal Trials for single-objective Study. The definition of pareto-optimal can be checked in wiki page. @@ -1214,10 +1260,8 @@ async def list_optimal_trials( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') request = vizier_service.ListOptimalTrialsRequest(request) @@ -1238,24 +1282,38 @@ async def list_optimal_trials( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("VizierServiceAsyncClient",) +__all__ = ( + 'VizierServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index 85e381323d..13587919b9 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -52,12 +52,13 @@ class VizierServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[VizierServiceTransport]] - _transport_registry["grpc"] = VizierServiceGrpcTransport - _transport_registry["grpc_asyncio"] = VizierServiceGrpcAsyncIOTransport + _transport_registry['grpc'] = VizierServiceGrpcTransport + _transport_registry['grpc_asyncio'] = VizierServiceGrpcAsyncIOTransport - def get_transport_class(cls, label: str = None,) -> Type[VizierServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[VizierServiceTransport]: """Return an appropriate transport class. Args: @@ -112,7 +113,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -147,8 +148,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: VizierServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -163,120 +165,99 @@ def transport(self) -> VizierServiceTransport: return self._transport @staticmethod - def custom_job_path(project: str, location: str, custom_job: str,) -> str: + def custom_job_path(project: str,location: str,custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str, str]: + def parse_custom_job_path(path: str) -> Dict[str,str]: """Parse a custom_job path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def study_path(project: str, location: str, study: str,) -> str: + def study_path(project: str,location: str,study: str,) -> str: """Return a fully-qualified study string.""" - return "projects/{project}/locations/{location}/studies/{study}".format( - project=project, location=location, study=study, - ) + return "projects/{project}/locations/{location}/studies/{study}".format(project=project, location=location, study=study, ) @staticmethod - def parse_study_path(path: str) -> Dict[str, str]: + def parse_study_path(path: str) -> Dict[str,str]: """Parse a study path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str, location: str, study: str, trial: str,) -> str: + def trial_path(project: str,location: str,study: str,trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( - project=project, location=location, study=study, trial=trial, - ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) @staticmethod - def parse_trial_path(path: str) -> Dict[str, str]: + def parse_trial_path(path: str) -> Dict[str,str]: """Parse a trial path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str,) -> str: + def common_billing_account_path(billing_account: str, ) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str, str]: + def parse_common_billing_account_path(path: str) -> Dict[str,str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str,) -> str: + def common_folder_path(folder: str, ) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder,) + return "folders/{folder}".format(folder=folder, ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str, str]: + def parse_common_folder_path(path: str) -> Dict[str,str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str,) -> str: + def common_organization_path(organization: str, ) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization,) + return "organizations/{organization}".format(organization=organization, ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str, str]: + def parse_common_organization_path(path: str) -> Dict[str,str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str,) -> str: + def common_project_path(project: str, ) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project,) + return "projects/{project}".format(project=project, ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str, str]: + def parse_common_project_path(path: str) -> Dict[str,str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str,) -> str: + def common_location_path(project: str, location: str, ) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + return "projects/{project}/locations/{location}".format(project=project, location=location, ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str, str]: + def parse_common_location_path(path: str) -> Dict[str,str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, VizierServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, VizierServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the vizier service client. Args: @@ -320,9 +301,7 @@ def __init__( client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) client_cert_source_func = None is_mtls = False @@ -332,9 +311,7 @@ def __init__( client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = ( - mtls.default_client_cert_source() if is_mtls else None - ) + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -346,9 +323,7 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -360,10 +335,8 @@ def __init__( if isinstance(transport, VizierServiceTransport): # transport is a VizierServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError( - "When providing a transport instance, " - "provide its credentials directly." - ) + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -382,16 +355,15 @@ def __init__( client_info=client_info, ) - def create_study( - self, - request: vizier_service.CreateStudyRequest = None, - *, - parent: str = None, - study: gca_study.Study = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_study.Study: + def create_study(self, + request: vizier_service.CreateStudyRequest = None, + *, + parent: str = None, + study: gca_study.Study = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_study.Study: r"""Creates a Study. A resource name will be generated after creation of the Study. @@ -430,10 +402,8 @@ def create_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.CreateStudyRequest. @@ -457,24 +427,30 @@ def create_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_study( - self, - request: vizier_service.GetStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + def get_study(self, + request: vizier_service.GetStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Gets a Study by name. Args: @@ -504,10 +480,8 @@ def get_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.GetStudyRequest. @@ -529,24 +503,30 @@ def get_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_studies( - self, - request: vizier_service.ListStudiesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListStudiesPager: + def list_studies(self, + request: vizier_service.ListStudiesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListStudiesPager: r"""Lists all the studies in a region for an associated project. @@ -583,10 +563,8 @@ def list_studies( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.ListStudiesRequest. @@ -608,30 +586,39 @@ def list_studies( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListStudiesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_study( - self, - request: vizier_service.DeleteStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def delete_study(self, + request: vizier_service.DeleteStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Study. Args: @@ -658,10 +645,8 @@ def delete_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.DeleteStudyRequest. @@ -683,23 +668,27 @@ def delete_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def lookup_study( - self, - request: vizier_service.LookupStudyRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + def lookup_study(self, + request: vizier_service.LookupStudyRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Looks a study up using the user-defined display_name field instead of the fully qualified resource name. @@ -731,10 +720,8 @@ def lookup_study( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.LookupStudyRequest. @@ -756,23 +743,29 @@ def lookup_study( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def suggest_trials( - self, - request: vizier_service.SuggestTrialsRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def suggest_trials(self, + request: vizier_service.SuggestTrialsRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Adds one or more Trials to a Study, with parameter values suggested by AI Platform Vizier. Returns a long-running operation associated with the generation of Trial suggestions. @@ -816,11 +809,18 @@ def suggest_trials( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation.from_gapic( @@ -833,16 +833,15 @@ def suggest_trials( # Done; return the response. return response - def create_trial( - self, - request: vizier_service.CreateTrialRequest = None, - *, - parent: str = None, - trial: study.Trial = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def create_trial(self, + request: vizier_service.CreateTrialRequest = None, + *, + parent: str = None, + trial: study.Trial = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a user provided Trial to a Study. Args: @@ -883,10 +882,8 @@ def create_trial( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.CreateTrialRequest. @@ -910,24 +907,30 @@ def create_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_trial( - self, - request: vizier_service.GetTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def get_trial(self, + request: vizier_service.GetTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Gets a Trial. Args: @@ -962,10 +965,8 @@ def get_trial( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.GetTrialRequest. @@ -987,24 +988,30 @@ def get_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_trials( - self, - request: vizier_service.ListTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrialsPager: + def list_trials(self, + request: vizier_service.ListTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrialsPager: r"""Lists the Trials associated with a Study. Args: @@ -1040,10 +1047,8 @@ def list_trials( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.ListTrialsRequest. @@ -1065,29 +1070,38 @@ def list_trials( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrialsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def add_trial_measurement( - self, - request: vizier_service.AddTrialMeasurementRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def add_trial_measurement(self, + request: vizier_service.AddTrialMeasurementRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a measurement of the objective metrics to a Trial. This measurement is assumed to have been taken before the Trial is complete. @@ -1128,25 +1142,29 @@ def add_trial_measurement( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("trial_name", request.trial_name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('trial_name', request.trial_name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def complete_trial( - self, - request: vizier_service.CompleteTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def complete_trial(self, + request: vizier_service.CompleteTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Marks a Trial as complete. Args: @@ -1185,24 +1203,30 @@ def complete_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_trial( - self, - request: vizier_service.DeleteTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def delete_trial(self, + request: vizier_service.DeleteTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Trial. Args: @@ -1228,10 +1252,8 @@ def delete_trial( # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.DeleteTrialRequest. @@ -1253,22 +1275,26 @@ def delete_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def check_trial_early_stopping_state( - self, - request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def check_trial_early_stopping_state(self, + request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Checks whether a Trial should stop or not. Returns a long-running operation. When the operation is successful, it will contain a @@ -1306,20 +1332,23 @@ def check_trial_early_stopping_state( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.check_trial_early_stopping_state - ] + rpc = self._transport._wrapped_methods[self._transport.check_trial_early_stopping_state] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("trial_name", request.trial_name),) - ), + gapic_v1.routing_header.to_grpc_metadata(( + ('trial_name', request.trial_name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = operation.from_gapic( @@ -1332,14 +1361,13 @@ def check_trial_early_stopping_state( # Done; return the response. return response - def stop_trial( - self, - request: vizier_service.StopTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def stop_trial(self, + request: vizier_service.StopTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Stops a Trial. Args: @@ -1378,24 +1406,30 @@ def stop_trial( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_optimal_trials( - self, - request: vizier_service.ListOptimalTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> vizier_service.ListOptimalTrialsResponse: + def list_optimal_trials(self, + request: vizier_service.ListOptimalTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> vizier_service.ListOptimalTrialsResponse: r"""Lists the pareto-optimal Trials for multi-objective Study or the optimal Trials for single-objective Study. The definition of pareto-optimal can be checked in wiki page. @@ -1430,10 +1464,8 @@ def list_optimal_trials( # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') # Minor optimization to avoid making a copy if the user passes # in a vizier_service.ListOptimalTrialsRequest. @@ -1455,24 +1487,38 @@ def list_optimal_trials( # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("VizierServiceClient",) +__all__ = ( + 'VizierServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py index c6e4fcdf63..5affed052e 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py @@ -15,16 +15,7 @@ # limitations under the License. # -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Iterable, - Sequence, - Tuple, - Optional, -) +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional from google.cloud.aiplatform_v1beta1.types import study from google.cloud.aiplatform_v1beta1.types import vizier_service @@ -47,15 +38,12 @@ class ListStudiesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., vizier_service.ListStudiesResponse], - request: vizier_service.ListStudiesRequest, - response: vizier_service.ListStudiesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., vizier_service.ListStudiesResponse], + request: vizier_service.ListStudiesRequest, + response: vizier_service.ListStudiesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -89,7 +77,7 @@ def __iter__(self) -> Iterable[study.Study]: yield from page.studies def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListStudiesAsyncPager: @@ -109,15 +97,12 @@ class ListStudiesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[vizier_service.ListStudiesResponse]], - request: vizier_service.ListStudiesRequest, - response: vizier_service.ListStudiesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[vizier_service.ListStudiesResponse]], + request: vizier_service.ListStudiesRequest, + response: vizier_service.ListStudiesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -155,7 +140,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListTrialsPager: @@ -175,15 +160,12 @@ class ListTrialsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., vizier_service.ListTrialsResponse], - request: vizier_service.ListTrialsRequest, - response: vizier_service.ListTrialsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., vizier_service.ListTrialsResponse], + request: vizier_service.ListTrialsRequest, + response: vizier_service.ListTrialsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -217,7 +199,7 @@ def __iter__(self) -> Iterable[study.Trial]: yield from page.trials def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListTrialsAsyncPager: @@ -237,15 +219,12 @@ class ListTrialsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[..., Awaitable[vizier_service.ListTrialsResponse]], - request: vizier_service.ListTrialsRequest, - response: vizier_service.ListTrialsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): + def __init__(self, + method: Callable[..., Awaitable[vizier_service.ListTrialsResponse]], + request: vizier_service.ListTrialsRequest, + response: vizier_service.ListTrialsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -283,4 +262,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py index 3ed347a603..de1a35ae04 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[VizierServiceTransport]] -_transport_registry["grpc"] = VizierServiceGrpcTransport -_transport_registry["grpc_asyncio"] = VizierServiceGrpcAsyncIOTransport +_transport_registry['grpc'] = VizierServiceGrpcTransport +_transport_registry['grpc_asyncio'] = VizierServiceGrpcAsyncIOTransport __all__ = ( - "VizierServiceTransport", - "VizierServiceGrpcTransport", - "VizierServiceGrpcAsyncIOTransport", + 'VizierServiceTransport', + 'VizierServiceGrpcTransport', + 'VizierServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py index 2fdfb4b13f..a6a5651b34 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - class VizierServiceTransport(abc.ABC): """Abstract transport class for VizierService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -74,69 +74,85 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_study: gapic_v1.method.wrap_method( - self.create_study, default_timeout=5.0, client_info=client_info, + self.create_study, + default_timeout=5.0, + client_info=client_info, ), self.get_study: gapic_v1.method.wrap_method( - self.get_study, default_timeout=5.0, client_info=client_info, + self.get_study, + default_timeout=5.0, + client_info=client_info, ), self.list_studies: gapic_v1.method.wrap_method( - self.list_studies, default_timeout=5.0, client_info=client_info, + self.list_studies, + default_timeout=5.0, + client_info=client_info, ), self.delete_study: gapic_v1.method.wrap_method( - self.delete_study, default_timeout=5.0, client_info=client_info, + self.delete_study, + default_timeout=5.0, + client_info=client_info, ), self.lookup_study: gapic_v1.method.wrap_method( - self.lookup_study, default_timeout=5.0, client_info=client_info, + self.lookup_study, + default_timeout=5.0, + client_info=client_info, ), self.suggest_trials: gapic_v1.method.wrap_method( - self.suggest_trials, default_timeout=5.0, client_info=client_info, + self.suggest_trials, + default_timeout=5.0, + client_info=client_info, ), self.create_trial: gapic_v1.method.wrap_method( - self.create_trial, default_timeout=5.0, client_info=client_info, + self.create_trial, + default_timeout=5.0, + client_info=client_info, ), self.get_trial: gapic_v1.method.wrap_method( - self.get_trial, default_timeout=5.0, client_info=client_info, + self.get_trial, + default_timeout=5.0, + client_info=client_info, ), self.list_trials: gapic_v1.method.wrap_method( - self.list_trials, default_timeout=5.0, client_info=client_info, + self.list_trials, + default_timeout=5.0, + client_info=client_info, ), self.add_trial_measurement: gapic_v1.method.wrap_method( self.add_trial_measurement, @@ -144,10 +160,14 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.complete_trial: gapic_v1.method.wrap_method( - self.complete_trial, default_timeout=5.0, client_info=client_info, + self.complete_trial, + default_timeout=5.0, + client_info=client_info, ), self.delete_trial: gapic_v1.method.wrap_method( - self.delete_trial, default_timeout=5.0, client_info=client_info, + self.delete_trial, + default_timeout=5.0, + client_info=client_info, ), self.check_trial_early_stopping_state: gapic_v1.method.wrap_method( self.check_trial_early_stopping_state, @@ -155,11 +175,16 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.stop_trial: gapic_v1.method.wrap_method( - self.stop_trial, default_timeout=5.0, client_info=client_info, + self.stop_trial, + default_timeout=5.0, + client_info=client_info, ), self.list_optimal_trials: gapic_v1.method.wrap_method( - self.list_optimal_trials, default_timeout=5.0, client_info=client_info, + self.list_optimal_trials, + default_timeout=5.0, + client_info=client_info, ), + } @property @@ -168,148 +193,141 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_study( - self, - ) -> typing.Callable[ - [vizier_service.CreateStudyRequest], - typing.Union[gca_study.Study, typing.Awaitable[gca_study.Study]], - ]: + def create_study(self) -> typing.Callable[ + [vizier_service.CreateStudyRequest], + typing.Union[ + gca_study.Study, + typing.Awaitable[gca_study.Study] + ]]: raise NotImplementedError() @property - def get_study( - self, - ) -> typing.Callable[ - [vizier_service.GetStudyRequest], - typing.Union[study.Study, typing.Awaitable[study.Study]], - ]: + def get_study(self) -> typing.Callable[ + [vizier_service.GetStudyRequest], + typing.Union[ + study.Study, + typing.Awaitable[study.Study] + ]]: raise NotImplementedError() @property - def list_studies( - self, - ) -> typing.Callable[ - [vizier_service.ListStudiesRequest], - typing.Union[ - vizier_service.ListStudiesResponse, - typing.Awaitable[vizier_service.ListStudiesResponse], - ], - ]: + def list_studies(self) -> typing.Callable[ + [vizier_service.ListStudiesRequest], + typing.Union[ + vizier_service.ListStudiesResponse, + typing.Awaitable[vizier_service.ListStudiesResponse] + ]]: raise NotImplementedError() @property - def delete_study( - self, - ) -> typing.Callable[ - [vizier_service.DeleteStudyRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def delete_study(self) -> typing.Callable[ + [vizier_service.DeleteStudyRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def lookup_study( - self, - ) -> typing.Callable[ - [vizier_service.LookupStudyRequest], - typing.Union[study.Study, typing.Awaitable[study.Study]], - ]: + def lookup_study(self) -> typing.Callable[ + [vizier_service.LookupStudyRequest], + typing.Union[ + study.Study, + typing.Awaitable[study.Study] + ]]: raise NotImplementedError() @property - def suggest_trials( - self, - ) -> typing.Callable[ - [vizier_service.SuggestTrialsRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def suggest_trials(self) -> typing.Callable[ + [vizier_service.SuggestTrialsRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def create_trial( - self, - ) -> typing.Callable[ - [vizier_service.CreateTrialRequest], - typing.Union[study.Trial, typing.Awaitable[study.Trial]], - ]: + def create_trial(self) -> typing.Callable[ + [vizier_service.CreateTrialRequest], + typing.Union[ + study.Trial, + typing.Awaitable[study.Trial] + ]]: raise NotImplementedError() @property - def get_trial( - self, - ) -> typing.Callable[ - [vizier_service.GetTrialRequest], - typing.Union[study.Trial, typing.Awaitable[study.Trial]], - ]: + def get_trial(self) -> typing.Callable[ + [vizier_service.GetTrialRequest], + typing.Union[ + study.Trial, + typing.Awaitable[study.Trial] + ]]: raise NotImplementedError() @property - def list_trials( - self, - ) -> typing.Callable[ - [vizier_service.ListTrialsRequest], - typing.Union[ - vizier_service.ListTrialsResponse, - typing.Awaitable[vizier_service.ListTrialsResponse], - ], - ]: + def list_trials(self) -> typing.Callable[ + [vizier_service.ListTrialsRequest], + typing.Union[ + vizier_service.ListTrialsResponse, + typing.Awaitable[vizier_service.ListTrialsResponse] + ]]: raise NotImplementedError() @property - def add_trial_measurement( - self, - ) -> typing.Callable[ - [vizier_service.AddTrialMeasurementRequest], - typing.Union[study.Trial, typing.Awaitable[study.Trial]], - ]: + def add_trial_measurement(self) -> typing.Callable[ + [vizier_service.AddTrialMeasurementRequest], + typing.Union[ + study.Trial, + typing.Awaitable[study.Trial] + ]]: raise NotImplementedError() @property - def complete_trial( - self, - ) -> typing.Callable[ - [vizier_service.CompleteTrialRequest], - typing.Union[study.Trial, typing.Awaitable[study.Trial]], - ]: + def complete_trial(self) -> typing.Callable[ + [vizier_service.CompleteTrialRequest], + typing.Union[ + study.Trial, + typing.Awaitable[study.Trial] + ]]: raise NotImplementedError() @property - def delete_trial( - self, - ) -> typing.Callable[ - [vizier_service.DeleteTrialRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: + def delete_trial(self) -> typing.Callable[ + [vizier_service.DeleteTrialRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: raise NotImplementedError() @property - def check_trial_early_stopping_state( - self, - ) -> typing.Callable[ - [vizier_service.CheckTrialEarlyStoppingStateRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: + def check_trial_early_stopping_state(self) -> typing.Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: raise NotImplementedError() @property - def stop_trial( - self, - ) -> typing.Callable[ - [vizier_service.StopTrialRequest], - typing.Union[study.Trial, typing.Awaitable[study.Trial]], - ]: + def stop_trial(self) -> typing.Callable[ + [vizier_service.StopTrialRequest], + typing.Union[ + study.Trial, + typing.Awaitable[study.Trial] + ]]: raise NotImplementedError() @property - def list_optimal_trials( - self, - ) -> typing.Callable[ - [vizier_service.ListOptimalTrialsRequest], - typing.Union[ - vizier_service.ListOptimalTrialsResponse, - typing.Awaitable[vizier_service.ListOptimalTrialsResponse], - ], - ]: + def list_optimal_trials(self) -> typing.Callable[ + [vizier_service.ListOptimalTrialsRequest], + typing.Union[ + vizier_service.ListOptimalTrialsResponse, + typing.Awaitable[vizier_service.ListOptimalTrialsResponse] + ]]: raise NotImplementedError() -__all__ = ("VizierServiceTransport",) +__all__ = ( + 'VizierServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py index 388d2746f5..a9e3db2e54 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,24 +51,21 @@ class VizierServiceGrpcTransport(VizierServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -114,7 +111,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -122,70 +122,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -193,32 +173,20 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - self._operations_client = None - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -248,12 +216,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -265,15 +234,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) # Return the client from cache. return self._operations_client @property - def create_study( - self, - ) -> Callable[[vizier_service.CreateStudyRequest], gca_study.Study]: + def create_study(self) -> Callable[ + [vizier_service.CreateStudyRequest], + gca_study.Study]: r"""Return a callable for the create study method over gRPC. Creates a Study. A resource name will be generated @@ -289,16 +260,18 @@ def create_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_study" not in self._stubs: - self._stubs["create_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy", + if 'create_study' not in self._stubs: + self._stubs['create_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy', request_serializer=vizier_service.CreateStudyRequest.serialize, response_deserializer=gca_study.Study.deserialize, ) - return self._stubs["create_study"] + return self._stubs['create_study'] @property - def get_study(self) -> Callable[[vizier_service.GetStudyRequest], study.Study]: + def get_study(self) -> Callable[ + [vizier_service.GetStudyRequest], + study.Study]: r"""Return a callable for the get study method over gRPC. Gets a Study by name. @@ -313,20 +286,18 @@ def get_study(self) -> Callable[[vizier_service.GetStudyRequest], study.Study]: # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_study" not in self._stubs: - self._stubs["get_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/GetStudy", + if 'get_study' not in self._stubs: + self._stubs['get_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/GetStudy', request_serializer=vizier_service.GetStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs["get_study"] + return self._stubs['get_study'] @property - def list_studies( - self, - ) -> Callable[ - [vizier_service.ListStudiesRequest], vizier_service.ListStudiesResponse - ]: + def list_studies(self) -> Callable[ + [vizier_service.ListStudiesRequest], + vizier_service.ListStudiesResponse]: r"""Return a callable for the list studies method over gRPC. Lists all the studies in a region for an associated @@ -342,18 +313,18 @@ def list_studies( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_studies" not in self._stubs: - self._stubs["list_studies"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/ListStudies", + if 'list_studies' not in self._stubs: + self._stubs['list_studies'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/ListStudies', request_serializer=vizier_service.ListStudiesRequest.serialize, response_deserializer=vizier_service.ListStudiesResponse.deserialize, ) - return self._stubs["list_studies"] + return self._stubs['list_studies'] @property - def delete_study( - self, - ) -> Callable[[vizier_service.DeleteStudyRequest], empty.Empty]: + def delete_study(self) -> Callable[ + [vizier_service.DeleteStudyRequest], + empty.Empty]: r"""Return a callable for the delete study method over gRPC. Deletes a Study. @@ -368,18 +339,18 @@ def delete_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_study" not in self._stubs: - self._stubs["delete_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy", + if 'delete_study' not in self._stubs: + self._stubs['delete_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy', request_serializer=vizier_service.DeleteStudyRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["delete_study"] + return self._stubs['delete_study'] @property - def lookup_study( - self, - ) -> Callable[[vizier_service.LookupStudyRequest], study.Study]: + def lookup_study(self) -> Callable[ + [vizier_service.LookupStudyRequest], + study.Study]: r"""Return a callable for the lookup study method over gRPC. Looks a study up using the user-defined display_name field @@ -395,18 +366,18 @@ def lookup_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "lookup_study" not in self._stubs: - self._stubs["lookup_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy", + if 'lookup_study' not in self._stubs: + self._stubs['lookup_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy', request_serializer=vizier_service.LookupStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs["lookup_study"] + return self._stubs['lookup_study'] @property - def suggest_trials( - self, - ) -> Callable[[vizier_service.SuggestTrialsRequest], operations.Operation]: + def suggest_trials(self) -> Callable[ + [vizier_service.SuggestTrialsRequest], + operations.Operation]: r"""Return a callable for the suggest trials method over gRPC. Adds one or more Trials to a Study, with parameter values @@ -425,18 +396,18 @@ def suggest_trials( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "suggest_trials" not in self._stubs: - self._stubs["suggest_trials"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials", + if 'suggest_trials' not in self._stubs: + self._stubs['suggest_trials'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials', request_serializer=vizier_service.SuggestTrialsRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["suggest_trials"] + return self._stubs['suggest_trials'] @property - def create_trial( - self, - ) -> Callable[[vizier_service.CreateTrialRequest], study.Trial]: + def create_trial(self) -> Callable[ + [vizier_service.CreateTrialRequest], + study.Trial]: r"""Return a callable for the create trial method over gRPC. Adds a user provided Trial to a Study. @@ -451,16 +422,18 @@ def create_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_trial" not in self._stubs: - self._stubs["create_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial", + if 'create_trial' not in self._stubs: + self._stubs['create_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial', request_serializer=vizier_service.CreateTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["create_trial"] + return self._stubs['create_trial'] @property - def get_trial(self) -> Callable[[vizier_service.GetTrialRequest], study.Trial]: + def get_trial(self) -> Callable[ + [vizier_service.GetTrialRequest], + study.Trial]: r"""Return a callable for the get trial method over gRPC. Gets a Trial. @@ -475,20 +448,18 @@ def get_trial(self) -> Callable[[vizier_service.GetTrialRequest], study.Trial]: # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_trial" not in self._stubs: - self._stubs["get_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/GetTrial", + if 'get_trial' not in self._stubs: + self._stubs['get_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/GetTrial', request_serializer=vizier_service.GetTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["get_trial"] + return self._stubs['get_trial'] @property - def list_trials( - self, - ) -> Callable[ - [vizier_service.ListTrialsRequest], vizier_service.ListTrialsResponse - ]: + def list_trials(self) -> Callable[ + [vizier_service.ListTrialsRequest], + vizier_service.ListTrialsResponse]: r"""Return a callable for the list trials method over gRPC. Lists the Trials associated with a Study. @@ -503,18 +474,18 @@ def list_trials( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_trials" not in self._stubs: - self._stubs["list_trials"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/ListTrials", + if 'list_trials' not in self._stubs: + self._stubs['list_trials'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/ListTrials', request_serializer=vizier_service.ListTrialsRequest.serialize, response_deserializer=vizier_service.ListTrialsResponse.deserialize, ) - return self._stubs["list_trials"] + return self._stubs['list_trials'] @property - def add_trial_measurement( - self, - ) -> Callable[[vizier_service.AddTrialMeasurementRequest], study.Trial]: + def add_trial_measurement(self) -> Callable[ + [vizier_service.AddTrialMeasurementRequest], + study.Trial]: r"""Return a callable for the add trial measurement method over gRPC. Adds a measurement of the objective metrics to a @@ -531,18 +502,18 @@ def add_trial_measurement( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "add_trial_measurement" not in self._stubs: - self._stubs["add_trial_measurement"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement", + if 'add_trial_measurement' not in self._stubs: + self._stubs['add_trial_measurement'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement', request_serializer=vizier_service.AddTrialMeasurementRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["add_trial_measurement"] + return self._stubs['add_trial_measurement'] @property - def complete_trial( - self, - ) -> Callable[[vizier_service.CompleteTrialRequest], study.Trial]: + def complete_trial(self) -> Callable[ + [vizier_service.CompleteTrialRequest], + study.Trial]: r"""Return a callable for the complete trial method over gRPC. Marks a Trial as complete. @@ -557,18 +528,18 @@ def complete_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "complete_trial" not in self._stubs: - self._stubs["complete_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial", + if 'complete_trial' not in self._stubs: + self._stubs['complete_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial', request_serializer=vizier_service.CompleteTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["complete_trial"] + return self._stubs['complete_trial'] @property - def delete_trial( - self, - ) -> Callable[[vizier_service.DeleteTrialRequest], empty.Empty]: + def delete_trial(self) -> Callable[ + [vizier_service.DeleteTrialRequest], + empty.Empty]: r"""Return a callable for the delete trial method over gRPC. Deletes a Trial. @@ -583,20 +554,18 @@ def delete_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_trial" not in self._stubs: - self._stubs["delete_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial", + if 'delete_trial' not in self._stubs: + self._stubs['delete_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial', request_serializer=vizier_service.DeleteTrialRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["delete_trial"] + return self._stubs['delete_trial'] @property - def check_trial_early_stopping_state( - self, - ) -> Callable[ - [vizier_service.CheckTrialEarlyStoppingStateRequest], operations.Operation - ]: + def check_trial_early_stopping_state(self) -> Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + operations.Operation]: r"""Return a callable for the check trial early stopping state method over gRPC. @@ -615,18 +584,18 @@ def check_trial_early_stopping_state( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "check_trial_early_stopping_state" not in self._stubs: - self._stubs[ - "check_trial_early_stopping_state" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState", + if 'check_trial_early_stopping_state' not in self._stubs: + self._stubs['check_trial_early_stopping_state'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState', request_serializer=vizier_service.CheckTrialEarlyStoppingStateRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["check_trial_early_stopping_state"] + return self._stubs['check_trial_early_stopping_state'] @property - def stop_trial(self) -> Callable[[vizier_service.StopTrialRequest], study.Trial]: + def stop_trial(self) -> Callable[ + [vizier_service.StopTrialRequest], + study.Trial]: r"""Return a callable for the stop trial method over gRPC. Stops a Trial. @@ -641,21 +610,18 @@ def stop_trial(self) -> Callable[[vizier_service.StopTrialRequest], study.Trial] # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "stop_trial" not in self._stubs: - self._stubs["stop_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/StopTrial", + if 'stop_trial' not in self._stubs: + self._stubs['stop_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/StopTrial', request_serializer=vizier_service.StopTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["stop_trial"] + return self._stubs['stop_trial'] @property - def list_optimal_trials( - self, - ) -> Callable[ - [vizier_service.ListOptimalTrialsRequest], - vizier_service.ListOptimalTrialsResponse, - ]: + def list_optimal_trials(self) -> Callable[ + [vizier_service.ListOptimalTrialsRequest], + vizier_service.ListOptimalTrialsResponse]: r"""Return a callable for the list optimal trials method over gRPC. Lists the pareto-optimal Trials for multi-objective Study or the @@ -673,13 +639,15 @@ def list_optimal_trials( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_optimal_trials" not in self._stubs: - self._stubs["list_optimal_trials"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials", + if 'list_optimal_trials' not in self._stubs: + self._stubs['list_optimal_trials'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials', request_serializer=vizier_service.ListOptimalTrialsRequest.serialize, response_deserializer=vizier_service.ListOptimalTrialsResponse.deserialize, ) - return self._stubs["list_optimal_trials"] + return self._stubs['list_optimal_trials'] -__all__ = ("VizierServiceGrpcTransport",) +__all__ = ( + 'VizierServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py index 82e28342a4..fedbc26b71 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import study @@ -58,18 +58,16 @@ class VizierServiceGrpcAsyncIOTransport(VizierServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -95,24 +93,22 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs, + **kwargs ) - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -147,10 +143,10 @@ def __init__( ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -159,7 +155,10 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None if api_mtls_endpoint: warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) @@ -167,70 +166,50 @@ def __init__( warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, + scopes=self._scopes, ssl_credentials=self._ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -238,18 +217,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - self._operations_client = None + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -278,9 +247,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_study( - self, - ) -> Callable[[vizier_service.CreateStudyRequest], Awaitable[gca_study.Study]]: + def create_study(self) -> Callable[ + [vizier_service.CreateStudyRequest], + Awaitable[gca_study.Study]]: r"""Return a callable for the create study method over gRPC. Creates a Study. A resource name will be generated @@ -296,18 +265,18 @@ def create_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_study" not in self._stubs: - self._stubs["create_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy", + if 'create_study' not in self._stubs: + self._stubs['create_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy', request_serializer=vizier_service.CreateStudyRequest.serialize, response_deserializer=gca_study.Study.deserialize, ) - return self._stubs["create_study"] + return self._stubs['create_study'] @property - def get_study( - self, - ) -> Callable[[vizier_service.GetStudyRequest], Awaitable[study.Study]]: + def get_study(self) -> Callable[ + [vizier_service.GetStudyRequest], + Awaitable[study.Study]]: r"""Return a callable for the get study method over gRPC. Gets a Study by name. @@ -322,21 +291,18 @@ def get_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_study" not in self._stubs: - self._stubs["get_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/GetStudy", + if 'get_study' not in self._stubs: + self._stubs['get_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/GetStudy', request_serializer=vizier_service.GetStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs["get_study"] + return self._stubs['get_study'] @property - def list_studies( - self, - ) -> Callable[ - [vizier_service.ListStudiesRequest], - Awaitable[vizier_service.ListStudiesResponse], - ]: + def list_studies(self) -> Callable[ + [vizier_service.ListStudiesRequest], + Awaitable[vizier_service.ListStudiesResponse]]: r"""Return a callable for the list studies method over gRPC. Lists all the studies in a region for an associated @@ -352,18 +318,18 @@ def list_studies( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_studies" not in self._stubs: - self._stubs["list_studies"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/ListStudies", + if 'list_studies' not in self._stubs: + self._stubs['list_studies'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/ListStudies', request_serializer=vizier_service.ListStudiesRequest.serialize, response_deserializer=vizier_service.ListStudiesResponse.deserialize, ) - return self._stubs["list_studies"] + return self._stubs['list_studies'] @property - def delete_study( - self, - ) -> Callable[[vizier_service.DeleteStudyRequest], Awaitable[empty.Empty]]: + def delete_study(self) -> Callable[ + [vizier_service.DeleteStudyRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the delete study method over gRPC. Deletes a Study. @@ -378,18 +344,18 @@ def delete_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_study" not in self._stubs: - self._stubs["delete_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy", + if 'delete_study' not in self._stubs: + self._stubs['delete_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy', request_serializer=vizier_service.DeleteStudyRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["delete_study"] + return self._stubs['delete_study'] @property - def lookup_study( - self, - ) -> Callable[[vizier_service.LookupStudyRequest], Awaitable[study.Study]]: + def lookup_study(self) -> Callable[ + [vizier_service.LookupStudyRequest], + Awaitable[study.Study]]: r"""Return a callable for the lookup study method over gRPC. Looks a study up using the user-defined display_name field @@ -405,20 +371,18 @@ def lookup_study( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "lookup_study" not in self._stubs: - self._stubs["lookup_study"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy", + if 'lookup_study' not in self._stubs: + self._stubs['lookup_study'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy', request_serializer=vizier_service.LookupStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs["lookup_study"] + return self._stubs['lookup_study'] @property - def suggest_trials( - self, - ) -> Callable[ - [vizier_service.SuggestTrialsRequest], Awaitable[operations.Operation] - ]: + def suggest_trials(self) -> Callable[ + [vizier_service.SuggestTrialsRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the suggest trials method over gRPC. Adds one or more Trials to a Study, with parameter values @@ -437,18 +401,18 @@ def suggest_trials( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "suggest_trials" not in self._stubs: - self._stubs["suggest_trials"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials", + if 'suggest_trials' not in self._stubs: + self._stubs['suggest_trials'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials', request_serializer=vizier_service.SuggestTrialsRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["suggest_trials"] + return self._stubs['suggest_trials'] @property - def create_trial( - self, - ) -> Callable[[vizier_service.CreateTrialRequest], Awaitable[study.Trial]]: + def create_trial(self) -> Callable[ + [vizier_service.CreateTrialRequest], + Awaitable[study.Trial]]: r"""Return a callable for the create trial method over gRPC. Adds a user provided Trial to a Study. @@ -463,18 +427,18 @@ def create_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_trial" not in self._stubs: - self._stubs["create_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial", + if 'create_trial' not in self._stubs: + self._stubs['create_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial', request_serializer=vizier_service.CreateTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["create_trial"] + return self._stubs['create_trial'] @property - def get_trial( - self, - ) -> Callable[[vizier_service.GetTrialRequest], Awaitable[study.Trial]]: + def get_trial(self) -> Callable[ + [vizier_service.GetTrialRequest], + Awaitable[study.Trial]]: r"""Return a callable for the get trial method over gRPC. Gets a Trial. @@ -489,20 +453,18 @@ def get_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_trial" not in self._stubs: - self._stubs["get_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/GetTrial", + if 'get_trial' not in self._stubs: + self._stubs['get_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/GetTrial', request_serializer=vizier_service.GetTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["get_trial"] + return self._stubs['get_trial'] @property - def list_trials( - self, - ) -> Callable[ - [vizier_service.ListTrialsRequest], Awaitable[vizier_service.ListTrialsResponse] - ]: + def list_trials(self) -> Callable[ + [vizier_service.ListTrialsRequest], + Awaitable[vizier_service.ListTrialsResponse]]: r"""Return a callable for the list trials method over gRPC. Lists the Trials associated with a Study. @@ -517,18 +479,18 @@ def list_trials( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_trials" not in self._stubs: - self._stubs["list_trials"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/ListTrials", + if 'list_trials' not in self._stubs: + self._stubs['list_trials'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/ListTrials', request_serializer=vizier_service.ListTrialsRequest.serialize, response_deserializer=vizier_service.ListTrialsResponse.deserialize, ) - return self._stubs["list_trials"] + return self._stubs['list_trials'] @property - def add_trial_measurement( - self, - ) -> Callable[[vizier_service.AddTrialMeasurementRequest], Awaitable[study.Trial]]: + def add_trial_measurement(self) -> Callable[ + [vizier_service.AddTrialMeasurementRequest], + Awaitable[study.Trial]]: r"""Return a callable for the add trial measurement method over gRPC. Adds a measurement of the objective metrics to a @@ -545,18 +507,18 @@ def add_trial_measurement( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "add_trial_measurement" not in self._stubs: - self._stubs["add_trial_measurement"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement", + if 'add_trial_measurement' not in self._stubs: + self._stubs['add_trial_measurement'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement', request_serializer=vizier_service.AddTrialMeasurementRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["add_trial_measurement"] + return self._stubs['add_trial_measurement'] @property - def complete_trial( - self, - ) -> Callable[[vizier_service.CompleteTrialRequest], Awaitable[study.Trial]]: + def complete_trial(self) -> Callable[ + [vizier_service.CompleteTrialRequest], + Awaitable[study.Trial]]: r"""Return a callable for the complete trial method over gRPC. Marks a Trial as complete. @@ -571,18 +533,18 @@ def complete_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "complete_trial" not in self._stubs: - self._stubs["complete_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial", + if 'complete_trial' not in self._stubs: + self._stubs['complete_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial', request_serializer=vizier_service.CompleteTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["complete_trial"] + return self._stubs['complete_trial'] @property - def delete_trial( - self, - ) -> Callable[[vizier_service.DeleteTrialRequest], Awaitable[empty.Empty]]: + def delete_trial(self) -> Callable[ + [vizier_service.DeleteTrialRequest], + Awaitable[empty.Empty]]: r"""Return a callable for the delete trial method over gRPC. Deletes a Trial. @@ -597,21 +559,18 @@ def delete_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_trial" not in self._stubs: - self._stubs["delete_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial", + if 'delete_trial' not in self._stubs: + self._stubs['delete_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial', request_serializer=vizier_service.DeleteTrialRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["delete_trial"] + return self._stubs['delete_trial'] @property - def check_trial_early_stopping_state( - self, - ) -> Callable[ - [vizier_service.CheckTrialEarlyStoppingStateRequest], - Awaitable[operations.Operation], - ]: + def check_trial_early_stopping_state(self) -> Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + Awaitable[operations.Operation]]: r"""Return a callable for the check trial early stopping state method over gRPC. @@ -630,20 +589,18 @@ def check_trial_early_stopping_state( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "check_trial_early_stopping_state" not in self._stubs: - self._stubs[ - "check_trial_early_stopping_state" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState", + if 'check_trial_early_stopping_state' not in self._stubs: + self._stubs['check_trial_early_stopping_state'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState', request_serializer=vizier_service.CheckTrialEarlyStoppingStateRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["check_trial_early_stopping_state"] + return self._stubs['check_trial_early_stopping_state'] @property - def stop_trial( - self, - ) -> Callable[[vizier_service.StopTrialRequest], Awaitable[study.Trial]]: + def stop_trial(self) -> Callable[ + [vizier_service.StopTrialRequest], + Awaitable[study.Trial]]: r"""Return a callable for the stop trial method over gRPC. Stops a Trial. @@ -658,21 +615,18 @@ def stop_trial( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "stop_trial" not in self._stubs: - self._stubs["stop_trial"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/StopTrial", + if 'stop_trial' not in self._stubs: + self._stubs['stop_trial'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/StopTrial', request_serializer=vizier_service.StopTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs["stop_trial"] + return self._stubs['stop_trial'] @property - def list_optimal_trials( - self, - ) -> Callable[ - [vizier_service.ListOptimalTrialsRequest], - Awaitable[vizier_service.ListOptimalTrialsResponse], - ]: + def list_optimal_trials(self) -> Callable[ + [vizier_service.ListOptimalTrialsRequest], + Awaitable[vizier_service.ListOptimalTrialsResponse]]: r"""Return a callable for the list optimal trials method over gRPC. Lists the pareto-optimal Trials for multi-objective Study or the @@ -690,13 +644,15 @@ def list_optimal_trials( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_optimal_trials" not in self._stubs: - self._stubs["list_optimal_trials"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials", + if 'list_optimal_trials' not in self._stubs: + self._stubs['list_optimal_trials'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials', request_serializer=vizier_service.ListOptimalTrialsRequest.serialize, response_deserializer=vizier_service.ListOptimalTrialsResponse.deserialize, ) - return self._stubs["list_optimal_trials"] + return self._stubs['list_optimal_trials'] -__all__ = ("VizierServiceGrpcAsyncIOTransport",) +__all__ = ( + 'VizierServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 2d2368df8c..8cc21f36ae 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,10 +15,24 @@ # limitations under the License. # -from .annotation import Annotation -from .annotation_spec import AnnotationSpec -from .batch_prediction_job import BatchPredictionJob -from .completion_stats import CompletionStats +from .annotation import ( + Annotation, +) +from .annotation_spec import ( + AnnotationSpec, +) +from .artifact import ( + Artifact, +) +from .batch_prediction_job import ( + BatchPredictionJob, +) +from .completion_stats import ( + CompletionStats, +) +from .context import ( + Context, +) from .custom_job import ( ContainerSpec, CustomJob, @@ -27,7 +41,9 @@ Scheduling, WorkerPoolSpec, ) -from .data_item import DataItem +from .data_item import ( + DataItem, +) from .data_labeling_job import ( ActiveLearningConfig, DataLabelingJob, @@ -59,8 +75,12 @@ ListDatasetsResponse, UpdateDatasetRequest, ) -from .deployed_model_ref import DeployedModelRef -from .encryption_spec import EncryptionSpec +from .deployed_model_ref import ( + DeployedModelRef, +) +from .encryption_spec import ( + EncryptionSpec, +) from .endpoint import ( DeployedModel, Endpoint, @@ -80,7 +100,15 @@ UndeployModelResponse, UpdateEndpointRequest, ) -from .env_var import EnvVar +from .env_var import ( + EnvVar, +) +from .event import ( + Event, +) +from .execution import ( + Execution, +) from .explanation import ( Attribution, Explanation, @@ -95,8 +123,15 @@ SmoothGradConfig, XraiAttribution, ) -from .explanation_metadata import ExplanationMetadata -from .hyperparameter_tuning_job import HyperparameterTuningJob +from .explanation_metadata import ( + ExplanationMetadata, +) +from .feature_monitoring_stats import ( + FeatureStatsAnomaly, +) +from .hyperparameter_tuning_job import ( + HyperparameterTuningJob, +) from .io import ( BigQueryDestination, BigQuerySource, @@ -113,14 +148,17 @@ CreateCustomJobRequest, CreateDataLabelingJobRequest, CreateHyperparameterTuningJobRequest, + CreateModelDeploymentMonitoringJobRequest, DeleteBatchPredictionJobRequest, DeleteCustomJobRequest, DeleteDataLabelingJobRequest, DeleteHyperparameterTuningJobRequest, + DeleteModelDeploymentMonitoringJobRequest, GetBatchPredictionJobRequest, GetCustomJobRequest, GetDataLabelingJobRequest, GetHyperparameterTuningJobRequest, + GetModelDeploymentMonitoringJobRequest, ListBatchPredictionJobsRequest, ListBatchPredictionJobsResponse, ListCustomJobsRequest, @@ -129,6 +167,17 @@ ListDataLabelingJobsResponse, ListHyperparameterTuningJobsRequest, ListHyperparameterTuningJobsResponse, + ListModelDeploymentMonitoringJobsRequest, + ListModelDeploymentMonitoringJobsResponse, + PauseModelDeploymentMonitoringJobRequest, + ResumeModelDeploymentMonitoringJobRequest, + SearchModelDeploymentMonitoringStatsAnomaliesRequest, + SearchModelDeploymentMonitoringStatsAnomaliesResponse, + UpdateModelDeploymentMonitoringJobOperationMetadata, + UpdateModelDeploymentMonitoringJobRequest, +) +from .lineage_subgraph import ( + LineageSubgraph, ) from .machine_resources import ( AutomaticResources, @@ -139,8 +188,55 @@ MachineSpec, ResourcesConsumed, ) -from .manual_batch_tuning_parameters import ManualBatchTuningParameters -from .migratable_resource import MigratableResource +from .manual_batch_tuning_parameters import ( + ManualBatchTuningParameters, +) +from .metadata_schema import ( + MetadataSchema, +) +from .metadata_service import ( + AddContextArtifactsAndExecutionsRequest, + AddContextArtifactsAndExecutionsResponse, + AddContextChildrenRequest, + AddContextChildrenResponse, + AddExecutionEventsRequest, + AddExecutionEventsResponse, + CreateArtifactRequest, + CreateContextRequest, + CreateExecutionRequest, + CreateMetadataSchemaRequest, + CreateMetadataStoreOperationMetadata, + CreateMetadataStoreRequest, + DeleteContextRequest, + DeleteMetadataStoreOperationMetadata, + DeleteMetadataStoreRequest, + GetArtifactRequest, + GetContextRequest, + GetExecutionRequest, + GetMetadataSchemaRequest, + GetMetadataStoreRequest, + ListArtifactsRequest, + ListArtifactsResponse, + ListContextsRequest, + ListContextsResponse, + ListExecutionsRequest, + ListExecutionsResponse, + ListMetadataSchemasRequest, + ListMetadataSchemasResponse, + ListMetadataStoresRequest, + ListMetadataStoresResponse, + QueryContextLineageSubgraphRequest, + QueryExecutionInputsAndOutputsRequest, + UpdateArtifactRequest, + UpdateContextRequest, + UpdateExecutionRequest, +) +from .metadata_store import ( + MetadataStore, +) +from .migratable_resource import ( + MigratableResource, +) from .migration_service import ( BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, @@ -156,8 +252,26 @@ Port, PredictSchemata, ) -from .model_evaluation import ModelEvaluation -from .model_evaluation_slice import ModelEvaluationSlice +from .model_deployment_monitoring_job import ( + ModelDeploymentMonitoringBigQueryTable, + ModelDeploymentMonitoringJob, + ModelDeploymentMonitoringObjectiveConfig, + ModelDeploymentMonitoringScheduleConfig, + ModelMonitoringStatsAnomalies, + ModelDeploymentMonitoringObjectiveType, +) +from .model_evaluation import ( + ModelEvaluation, +) +from .model_evaluation_slice import ( + ModelEvaluationSlice, +) +from .model_monitoring import ( + ModelMonitoringAlertConfig, + ModelMonitoringObjectiveConfig, + SamplingStrategy, + ThresholdConfig, +) from .model_service import ( DeleteModelRequest, ExportModelOperationMetadata, @@ -195,7 +309,9 @@ PredictRequest, PredictResponse, ) -from .specialist_pool import SpecialistPool +from .specialist_pool import ( + SpecialistPool, +) from .specialist_pool_service import ( CreateSpecialistPoolOperationMetadata, CreateSpecialistPoolRequest, @@ -220,7 +336,9 @@ TimestampSplit, TrainingPipeline, ) -from .user_action_reference import UserActionReference +from .user_action_reference import ( + UserActionReference, +) from .vizier_service import ( AddTrialMeasurementRequest, CheckTrialEarlyStoppingStateMetatdata, @@ -247,197 +365,261 @@ ) __all__ = ( - "AcceleratorType", - "Annotation", - "AnnotationSpec", - "BatchPredictionJob", - "CompletionStats", - "ContainerSpec", - "CustomJob", - "CustomJobSpec", - "PythonPackageSpec", - "Scheduling", - "WorkerPoolSpec", - "DataItem", - "ActiveLearningConfig", - "DataLabelingJob", - "SampleConfig", - "TrainingConfig", - "Dataset", - "ExportDataConfig", - "ImportDataConfig", - "CreateDatasetOperationMetadata", - "CreateDatasetRequest", - "DeleteDatasetRequest", - "ExportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "GetAnnotationSpecRequest", - "GetDatasetRequest", - "ImportDataOperationMetadata", - "ImportDataRequest", - "ImportDataResponse", - "ListAnnotationsRequest", - "ListAnnotationsResponse", - "ListDataItemsRequest", - "ListDataItemsResponse", - "ListDatasetsRequest", - "ListDatasetsResponse", - "UpdateDatasetRequest", - "DeployedModelRef", - "EncryptionSpec", - "DeployedModel", - "Endpoint", - "CreateEndpointOperationMetadata", - "CreateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelOperationMetadata", - "DeployModelRequest", - "DeployModelResponse", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UndeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UpdateEndpointRequest", - "EnvVar", - "Attribution", - "Explanation", - "ExplanationMetadataOverride", - "ExplanationParameters", - "ExplanationSpec", - "ExplanationSpecOverride", - "FeatureNoiseSigma", - "IntegratedGradientsAttribution", - "ModelExplanation", - "SampledShapleyAttribution", - "SmoothGradConfig", - "XraiAttribution", - "ExplanationMetadata", - "HyperparameterTuningJob", - "BigQueryDestination", - "BigQuerySource", - "ContainerRegistryDestination", - "GcsDestination", - "GcsSource", - "CancelBatchPredictionJobRequest", - "CancelCustomJobRequest", - "CancelDataLabelingJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "CreateCustomJobRequest", - "CreateDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "DeleteBatchPredictionJobRequest", - "DeleteCustomJobRequest", - "DeleteDataLabelingJobRequest", - "DeleteHyperparameterTuningJobRequest", - "GetBatchPredictionJobRequest", - "GetCustomJobRequest", - "GetDataLabelingJobRequest", - "GetHyperparameterTuningJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "JobState", - "AutomaticResources", - "AutoscalingMetricSpec", - "BatchDedicatedResources", - "DedicatedResources", - "DiskSpec", - "MachineSpec", - "ResourcesConsumed", - "ManualBatchTuningParameters", - "MigratableResource", - "BatchMigrateResourcesOperationMetadata", - "BatchMigrateResourcesRequest", - "BatchMigrateResourcesResponse", - "MigrateResourceRequest", - "MigrateResourceResponse", - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", - "Model", - "ModelContainerSpec", - "Port", - "PredictSchemata", - "ModelEvaluation", - "ModelEvaluationSlice", - "DeleteModelRequest", - "ExportModelOperationMetadata", - "ExportModelRequest", - "ExportModelResponse", - "GetModelEvaluationRequest", - "GetModelEvaluationSliceRequest", - "GetModelRequest", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", - "UploadModelOperationMetadata", - "UploadModelRequest", - "UploadModelResponse", - "DeleteOperationMetadata", - "GenericOperationMetadata", - "CancelTrainingPipelineRequest", - "CreateTrainingPipelineRequest", - "DeleteTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "PipelineState", - "ExplainRequest", - "ExplainResponse", - "PredictRequest", - "PredictResponse", - "SpecialistPool", - "CreateSpecialistPoolOperationMetadata", - "CreateSpecialistPoolRequest", - "DeleteSpecialistPoolRequest", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "UpdateSpecialistPoolOperationMetadata", - "UpdateSpecialistPoolRequest", - "Measurement", - "Study", - "StudySpec", - "Trial", - "FilterSplit", - "FractionSplit", - "InputDataConfig", - "PredefinedSplit", - "TimestampSplit", - "TrainingPipeline", - "UserActionReference", - "AddTrialMeasurementRequest", - "CheckTrialEarlyStoppingStateMetatdata", - "CheckTrialEarlyStoppingStateRequest", - "CheckTrialEarlyStoppingStateResponse", - "CompleteTrialRequest", - "CreateStudyRequest", - "CreateTrialRequest", - "DeleteStudyRequest", - "DeleteTrialRequest", - "GetStudyRequest", - "GetTrialRequest", - "ListOptimalTrialsRequest", - "ListOptimalTrialsResponse", - "ListStudiesRequest", - "ListStudiesResponse", - "ListTrialsRequest", - "ListTrialsResponse", - "LookupStudyRequest", - "StopTrialRequest", - "SuggestTrialsMetadata", - "SuggestTrialsRequest", - "SuggestTrialsResponse", + 'AcceleratorType', + 'Annotation', + 'AnnotationSpec', + 'Artifact', + 'BatchPredictionJob', + 'CompletionStats', + 'Context', + 'ContainerSpec', + 'CustomJob', + 'CustomJobSpec', + 'PythonPackageSpec', + 'Scheduling', + 'WorkerPoolSpec', + 'DataItem', + 'ActiveLearningConfig', + 'DataLabelingJob', + 'SampleConfig', + 'TrainingConfig', + 'Dataset', + 'ExportDataConfig', + 'ImportDataConfig', + 'CreateDatasetOperationMetadata', + 'CreateDatasetRequest', + 'DeleteDatasetRequest', + 'ExportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'GetAnnotationSpecRequest', + 'GetDatasetRequest', + 'ImportDataOperationMetadata', + 'ImportDataRequest', + 'ImportDataResponse', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'UpdateDatasetRequest', + 'DeployedModelRef', + 'EncryptionSpec', + 'DeployedModel', + 'Endpoint', + 'CreateEndpointOperationMetadata', + 'CreateEndpointRequest', + 'DeleteEndpointRequest', + 'DeployModelOperationMetadata', + 'DeployModelRequest', + 'DeployModelResponse', + 'GetEndpointRequest', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'UndeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UpdateEndpointRequest', + 'EnvVar', + 'Event', + 'Execution', + 'Attribution', + 'Explanation', + 'ExplanationMetadataOverride', + 'ExplanationParameters', + 'ExplanationSpec', + 'ExplanationSpecOverride', + 'FeatureNoiseSigma', + 'IntegratedGradientsAttribution', + 'ModelExplanation', + 'SampledShapleyAttribution', + 'SmoothGradConfig', + 'XraiAttribution', + 'ExplanationMetadata', + 'FeatureStatsAnomaly', + 'HyperparameterTuningJob', + 'BigQueryDestination', + 'BigQuerySource', + 'ContainerRegistryDestination', + 'GcsDestination', + 'GcsSource', + 'CancelBatchPredictionJobRequest', + 'CancelCustomJobRequest', + 'CancelDataLabelingJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CreateBatchPredictionJobRequest', + 'CreateCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'CreateHyperparameterTuningJobRequest', + 'CreateModelDeploymentMonitoringJobRequest', + 'DeleteBatchPredictionJobRequest', + 'DeleteCustomJobRequest', + 'DeleteDataLabelingJobRequest', + 'DeleteHyperparameterTuningJobRequest', + 'DeleteModelDeploymentMonitoringJobRequest', + 'GetBatchPredictionJobRequest', + 'GetCustomJobRequest', + 'GetDataLabelingJobRequest', + 'GetHyperparameterTuningJobRequest', + 'GetModelDeploymentMonitoringJobRequest', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'ListModelDeploymentMonitoringJobsRequest', + 'ListModelDeploymentMonitoringJobsResponse', + 'PauseModelDeploymentMonitoringJobRequest', + 'ResumeModelDeploymentMonitoringJobRequest', + 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', + 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', + 'UpdateModelDeploymentMonitoringJobOperationMetadata', + 'UpdateModelDeploymentMonitoringJobRequest', + 'JobState', + 'LineageSubgraph', + 'AutomaticResources', + 'AutoscalingMetricSpec', + 'BatchDedicatedResources', + 'DedicatedResources', + 'DiskSpec', + 'MachineSpec', + 'ResourcesConsumed', + 'ManualBatchTuningParameters', + 'MetadataSchema', + 'AddContextArtifactsAndExecutionsRequest', + 'AddContextArtifactsAndExecutionsResponse', + 'AddContextChildrenRequest', + 'AddContextChildrenResponse', + 'AddExecutionEventsRequest', + 'AddExecutionEventsResponse', + 'CreateArtifactRequest', + 'CreateContextRequest', + 'CreateExecutionRequest', + 'CreateMetadataSchemaRequest', + 'CreateMetadataStoreOperationMetadata', + 'CreateMetadataStoreRequest', + 'DeleteContextRequest', + 'DeleteMetadataStoreOperationMetadata', + 'DeleteMetadataStoreRequest', + 'GetArtifactRequest', + 'GetContextRequest', + 'GetExecutionRequest', + 'GetMetadataSchemaRequest', + 'GetMetadataStoreRequest', + 'ListArtifactsRequest', + 'ListArtifactsResponse', + 'ListContextsRequest', + 'ListContextsResponse', + 'ListExecutionsRequest', + 'ListExecutionsResponse', + 'ListMetadataSchemasRequest', + 'ListMetadataSchemasResponse', + 'ListMetadataStoresRequest', + 'ListMetadataStoresResponse', + 'QueryContextLineageSubgraphRequest', + 'QueryExecutionInputsAndOutputsRequest', + 'UpdateArtifactRequest', + 'UpdateContextRequest', + 'UpdateExecutionRequest', + 'MetadataStore', + 'MigratableResource', + 'BatchMigrateResourcesOperationMetadata', + 'BatchMigrateResourcesRequest', + 'BatchMigrateResourcesResponse', + 'MigrateResourceRequest', + 'MigrateResourceResponse', + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'Model', + 'ModelContainerSpec', + 'Port', + 'PredictSchemata', + 'ModelDeploymentMonitoringBigQueryTable', + 'ModelDeploymentMonitoringJob', + 'ModelDeploymentMonitoringObjectiveConfig', + 'ModelDeploymentMonitoringScheduleConfig', + 'ModelMonitoringStatsAnomalies', + 'ModelDeploymentMonitoringObjectiveType', + 'ModelEvaluation', + 'ModelEvaluationSlice', + 'ModelMonitoringAlertConfig', + 'ModelMonitoringObjectiveConfig', + 'SamplingStrategy', + 'ThresholdConfig', + 'DeleteModelRequest', + 'ExportModelOperationMetadata', + 'ExportModelRequest', + 'ExportModelResponse', + 'GetModelEvaluationRequest', + 'GetModelEvaluationSliceRequest', + 'GetModelRequest', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'ListModelsRequest', + 'ListModelsResponse', + 'UpdateModelRequest', + 'UploadModelOperationMetadata', + 'UploadModelRequest', + 'UploadModelResponse', + 'DeleteOperationMetadata', + 'GenericOperationMetadata', + 'CancelTrainingPipelineRequest', + 'CreateTrainingPipelineRequest', + 'DeleteTrainingPipelineRequest', + 'GetTrainingPipelineRequest', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'PipelineState', + 'ExplainRequest', + 'ExplainResponse', + 'PredictRequest', + 'PredictResponse', + 'SpecialistPool', + 'CreateSpecialistPoolOperationMetadata', + 'CreateSpecialistPoolRequest', + 'DeleteSpecialistPoolRequest', + 'GetSpecialistPoolRequest', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'UpdateSpecialistPoolOperationMetadata', + 'UpdateSpecialistPoolRequest', + 'Measurement', + 'Study', + 'StudySpec', + 'Trial', + 'FilterSplit', + 'FractionSplit', + 'InputDataConfig', + 'PredefinedSplit', + 'TimestampSplit', + 'TrainingPipeline', + 'UserActionReference', + 'AddTrialMeasurementRequest', + 'CheckTrialEarlyStoppingStateMetatdata', + 'CheckTrialEarlyStoppingStateRequest', + 'CheckTrialEarlyStoppingStateResponse', + 'CompleteTrialRequest', + 'CreateStudyRequest', + 'CreateTrialRequest', + 'DeleteStudyRequest', + 'DeleteTrialRequest', + 'GetStudyRequest', + 'GetTrialRequest', + 'ListOptimalTrialsRequest', + 'ListOptimalTrialsResponse', + 'ListStudiesRequest', + 'ListStudiesResponse', + 'ListTrialsRequest', + 'ListTrialsResponse', + 'LookupStudyRequest', + 'StopTrialRequest', + 'SuggestTrialsMetadata', + 'SuggestTrialsRequest', + 'SuggestTrialsResponse', ) diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index 8c6968952c..65471c7234 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"AcceleratorType",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'AcceleratorType', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index a42ef0da82..4b769480a8 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -24,7 +24,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"Annotation",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Annotation', + }, ) @@ -91,16 +94,22 @@ class Annotation(proto.Message): payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + payload = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field( - proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, + annotation_source = proto.Field(proto.MESSAGE, number=5, + message=user_action_reference.UserActionReference, ) labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index e921e25971..b60bcebb5f 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"AnnotationSpec",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'AnnotationSpec', + }, ) @@ -55,9 +58,13 @@ class AnnotationSpec(proto.Message): display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/artifact.py b/google/cloud/aiplatform_v1beta1/types/artifact.py new file mode 100644 index 0000000000..b35ae286d7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/artifact.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Artifact', + }, +) + + +class Artifact(proto.Message): + r"""Instance of a general artifact. + + Attributes: + name (str): + Output only. The resource name of the + Artifact. + display_name (str): + User provided display name of the Artifact. + May be up to 128 Unicode characters. + uri (str): + The uniform resource identifier of the + artifact file. May be empty if there is no + actual artifact file. + etag (str): + An eTag used to perform consistent read- + odify-write updates. If not set, a blind + "overwrite" update happens. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Artifact.LabelsEntry]): + The labels with user-defined metadata to organize your + Artifacts. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. No more than 64 user labels can be + associated with one Artifact (System labels are excluded). + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each Artifact: + + - "aiplatform.googleapis.com/schema_title": + + - output only, its value is the title of the Artifact + schema provided either by [instance_schema_uri][] or + [instance_schema][]. + + - "aiplatform.googleapis.com/schema_version": + + - output only, its value is the schema version of the + Artifact schema provided either by + [instance_schema_uri][] or [instance_schema][]. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Artifact was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Artifact was + last updated. + state (google.cloud.aiplatform_v1beta1.types.Artifact.State): + The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as AI Platform + Pipelines), and the system does not prescribe or + check the validity of state transitions. + schema_title (str): + The title of the schema describing the + metadata. + Schema title and version is expected to be + registered in earlier Create Schema calls. And + both are used together as unique identifiers to + identify schemas within the local metadata + store. + schema_version (str): + The version of the schema in schema_name to use. + + Schema title and version is expected to be registered in + earlier Create Schema calls. And both are used together as + unique identifiers to identify schemas within the local + metadata store. + metadata (google.protobuf.struct_pb2.Struct): + Properties of the Artifact. + description (str): + Description of the Artifact + """ + class State(proto.Enum): + r"""Describes the state of the Artifact.""" + STATE_UNSPECIFIED = 0 + PENDING = 1 + LIVE = 2 + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + uri = proto.Field(proto.STRING, number=6) + + etag = proto.Field(proto.STRING, number=9) + + labels = proto.MapField(proto.STRING, proto.STRING, number=10) + + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) + + state = proto.Field(proto.ENUM, number=13, + enum=State, + ) + + schema_title = proto.Field(proto.STRING, number=14) + + schema_version = proto.Field(proto.STRING, number=15) + + metadata = proto.Field(proto.MESSAGE, number=16, + message=struct.Struct, + ) + + description = proto.Field(proto.STRING, number=17) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 9c79349b9e..b2bcab9302 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -18,24 +18,23 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - completion_stats as gca_completion_stats, -) +from google.cloud.aiplatform_v1beta1.types import completion_stats as gca_completion_stats from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import ( - manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, -) +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"BatchPredictionJob",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'BatchPredictionJob', + }, ) @@ -191,7 +190,6 @@ class BatchPredictionJob(proto.Message): resources created by the BatchPredictionJob will be encrypted with the provided encryption key. """ - class InputConfig(proto.Message): r"""Configures the input to ``BatchPredictionJob``. @@ -218,12 +216,12 @@ class InputConfig(proto.Message): ``supported_input_storage_formats``. """ - gcs_source = proto.Field( - proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, + gcs_source = proto.Field(proto.MESSAGE, number=2, oneof='source', + message=io.GcsSource, ) - bigquery_source = proto.Field( - proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, + bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', + message=io.BigQuerySource, ) instances_format = proto.Field(proto.STRING, number=1) @@ -265,9 +263,9 @@ class OutputConfig(proto.Message): which as value has ```google.rpc.Status`` `__ containing only ``code`` and ``message`` fields. bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): - The BigQuery project location where the output is to be - written to. In the given project a new dataset is created - with name + The BigQuery project or dataset location where the output is + to be written to. If project is provided, a new dataset is + created with name ``prediction__`` where is made BigQuery-dataset-name compatible (for example, most special characters become underscores), and timestamp is in @@ -293,14 +291,11 @@ class OutputConfig(proto.Message): ``supported_output_storage_formats``. """ - gcs_destination = proto.Field( - proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', + message=io.GcsDestination, ) - bigquery_destination = proto.Field( - proto.MESSAGE, - number=3, - oneof="destination", + bigquery_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', message=io.BigQueryDestination, ) @@ -321,13 +316,9 @@ class OutputInfo(proto.Message): prediction output is written. """ - gcs_output_directory = proto.Field( - proto.STRING, number=1, oneof="output_location" - ) + gcs_output_directory = proto.Field(proto.STRING, number=1, oneof='output_location') - bigquery_output_dataset = proto.Field( - proto.STRING, number=2, oneof="output_location" - ) + bigquery_output_dataset = proto.Field(proto.STRING, number=2, oneof='output_location') name = proto.Field(proto.STRING, number=1) @@ -335,58 +326,76 @@ class OutputInfo(proto.Message): model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) + input_config = proto.Field(proto.MESSAGE, number=4, + message=InputConfig, + ) - model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + model_parameters = proto.Field(proto.MESSAGE, number=5, + message=struct.Value, + ) - output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) + output_config = proto.Field(proto.MESSAGE, number=6, + message=OutputConfig, + ) - dedicated_resources = proto.Field( - proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, + dedicated_resources = proto.Field(proto.MESSAGE, number=7, + message=machine_resources.BatchDedicatedResources, ) - manual_batch_tuning_parameters = proto.Field( - proto.MESSAGE, - number=8, + manual_batch_tuning_parameters = proto.Field(proto.MESSAGE, number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) generate_explanation = proto.Field(proto.BOOL, number=23) - explanation_spec = proto.Field( - proto.MESSAGE, number=25, message=explanation.ExplanationSpec, + explanation_spec = proto.Field(proto.MESSAGE, number=25, + message=explanation.ExplanationSpec, ) - output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) + output_info = proto.Field(proto.MESSAGE, number=9, + message=OutputInfo, + ) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=10, + enum=job_state.JobState, + ) - error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=11, + message=status.Status, + ) - partial_failures = proto.RepeatedField( - proto.MESSAGE, number=12, message=status.Status, + partial_failures = proto.RepeatedField(proto.MESSAGE, number=12, + message=status.Status, ) - resources_consumed = proto.Field( - proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, + resources_consumed = proto.Field(proto.MESSAGE, number=13, + message=machine_resources.ResourcesConsumed, ) - completion_stats = proto.Field( - proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, + completion_stats = proto.Field(proto.MESSAGE, number=14, + message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=15, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=16, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=17, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=18, + message=timestamp.Timestamp, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=19) - encryption_spec = proto.Field( - proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=24, + message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/completion_stats.py b/google/cloud/aiplatform_v1beta1/types/completion_stats.py index 165be59634..3874f412df 100644 --- a/google/cloud/aiplatform_v1beta1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/completion_stats.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"CompletionStats",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'CompletionStats', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/context.py b/google/cloud/aiplatform_v1beta1/types/context.py new file mode 100644 index 0000000000..59f5289b48 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/context.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Context', + }, +) + + +class Context(proto.Message): + r"""Instance of a general context. + + Attributes: + name (str): + Output only. The resource name of the + Context. + display_name (str): + User provided display name of the Context. + May be up to 128 Unicode characters. + etag (str): + An eTag used to perform consistent read- + odify-write updates. If not set, a blind + "overwrite" update happens. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Context.LabelsEntry]): + The labels with user-defined metadata to organize your + Contexts. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. No more than 64 user labels can be + associated with one Context (System labels are excluded). + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each Context: + + - "aiplatform.googleapis.com/schema_title": + + - output only, its value is the title of the Context + schema provided either by [instance_schema_uri][] or + [instance_schema][]. + + - "aiplatform.googleapis.com/schema_version": + + - output only, its value is the schema version of the + Context schema provided either by + [instance_schema_uri][] or [instance_schema][]. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Context was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Context was + last updated. + parent_contexts (Sequence[str]): + Output only. A list of resource names of Contexts that are + parents of this Context. A Context may have at most 10 + parent_contexts. + schema_title (str): + The title of the schema describing the + metadata. + Schema title and version is expected to be + registered in earlier Create Schema calls. And + both are used together as unique identifiers to + identify schemas within the local metadata + store. + schema_version (str): + The version of the schema in schema_name to use. + + Schema title and version is expected to be registered in + earlier Create Schema calls. And both are used together as + unique identifiers to identify schemas within the local + metadata store. + metadata (google.protobuf.struct_pb2.Struct): + Properties of the Context. + description (str): + Description of the Context + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + etag = proto.Field(proto.STRING, number=8) + + labels = proto.MapField(proto.STRING, proto.STRING, number=9) + + create_time = proto.Field(proto.MESSAGE, number=10, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) + + parent_contexts = proto.RepeatedField(proto.STRING, number=12) + + schema_title = proto.Field(proto.STRING, number=13) + + schema_version = proto.Field(proto.STRING, number=14) + + metadata = proto.Field(proto.MESSAGE, number=15, + message=struct.Struct, + ) + + description = proto.Field(proto.STRING, number=16) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 1d148b7777..9de4e3b5fa 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CustomJob", - "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", - "PythonPackageSpec", - "Scheduling", + 'CustomJob', + 'CustomJobSpec', + 'WorkerPoolSpec', + 'ContainerSpec', + 'PythonPackageSpec', + 'Scheduling', }, ) @@ -95,24 +95,38 @@ class CustomJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) + job_spec = proto.Field(proto.MESSAGE, number=4, + message='CustomJobSpec', + ) - state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=5, + enum=job_state.JobState, + ) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=10, + message=status.Status, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=11) - encryption_spec = proto.Field( - proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=12, + message=gca_encryption_spec.EncryptionSpec, ) @@ -177,18 +191,20 @@ class CustomJobSpec(proto.Message): ``//logs/`` """ - worker_pool_specs = proto.RepeatedField( - proto.MESSAGE, number=1, message="WorkerPoolSpec", + worker_pool_specs = proto.RepeatedField(proto.MESSAGE, number=1, + message='WorkerPoolSpec', ) - scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) + scheduling = proto.Field(proto.MESSAGE, number=3, + message='Scheduling', + ) service_account = proto.Field(proto.STRING, number=4) network = proto.Field(proto.STRING, number=5) - base_output_directory = proto.Field( - proto.MESSAGE, number=6, message=io.GcsDestination, + base_output_directory = proto.Field(proto.MESSAGE, number=6, + message=io.GcsDestination, ) @@ -210,22 +226,22 @@ class WorkerPoolSpec(proto.Message): Disk spec. """ - container_spec = proto.Field( - proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", + container_spec = proto.Field(proto.MESSAGE, number=6, oneof='task', + message='ContainerSpec', ) - python_package_spec = proto.Field( - proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", + python_package_spec = proto.Field(proto.MESSAGE, number=7, oneof='task', + message='PythonPackageSpec', ) - machine_spec = proto.Field( - proto.MESSAGE, number=1, message=machine_resources.MachineSpec, + machine_spec = proto.Field(proto.MESSAGE, number=1, + message=machine_resources.MachineSpec, ) replica_count = proto.Field(proto.INT64, number=2) - disk_spec = proto.Field( - proto.MESSAGE, number=5, message=machine_resources.DiskSpec, + disk_spec = proto.Field(proto.MESSAGE, number=5, + message=machine_resources.DiskSpec, ) @@ -302,7 +318,9 @@ class Scheduling(proto.Message): to workers leaving and joining a job. """ - timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + timeout = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index a12776f06c..5c50d8e526 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -23,7 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"DataItem",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'DataItem', + }, ) @@ -70,13 +73,19 @@ class DataItem(proto.Message): name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=2, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + payload = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index d750f53e66..0b123cc88e 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -27,12 +27,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "DataLabelingJob", - "ActiveLearningConfig", - "SampleConfig", - "TrainingConfig", + 'DataLabelingJob', + 'ActiveLearningConfig', + 'SampleConfig', + 'TrainingConfig', }, ) @@ -154,30 +154,42 @@ class DataLabelingJob(proto.Message): inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) + inputs = proto.Field(proto.MESSAGE, number=7, + message=struct.Value, + ) - state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=8, + enum=job_state.JobState, + ) labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) + current_spend = proto.Field(proto.MESSAGE, number=14, + message=money.Money, + ) - create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=10, + message=timestamp.Timestamp, + ) - error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=22, + message=status.Status, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=11) specialist_pools = proto.RepeatedField(proto.STRING, number=16) - encryption_spec = proto.Field( - proto.MESSAGE, number=20, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=20, + message=gca_encryption_spec.EncryptionSpec, ) - active_learning_config = proto.Field( - proto.MESSAGE, number=21, message="ActiveLearningConfig", + active_learning_config = proto.Field(proto.MESSAGE, number=21, + message='ActiveLearningConfig', ) @@ -206,17 +218,17 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field( - proto.INT64, number=1, oneof="human_labeling_budget" - ) + max_data_item_count = proto.Field(proto.INT64, number=1, oneof='human_labeling_budget') - max_data_item_percentage = proto.Field( - proto.INT32, number=2, oneof="human_labeling_budget" - ) + max_data_item_percentage = proto.Field(proto.INT32, number=2, oneof='human_labeling_budget') - sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + sample_config = proto.Field(proto.MESSAGE, number=3, + message='SampleConfig', + ) - training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + training_config = proto.Field(proto.MESSAGE, number=4, + message='TrainingConfig', + ) class SampleConfig(proto.Message): @@ -237,7 +249,6 @@ class SampleConfig(proto.Message): strategy will decide which data should be selected for human labeling in every batch. """ - class SampleStrategy(proto.Enum): r"""Sample strategy decides which subset of DataItems should be selected for human labeling in every batch. @@ -245,15 +256,13 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field( - proto.INT32, number=1, oneof="initial_batch_sample_size" - ) + initial_batch_sample_percentage = proto.Field(proto.INT32, number=1, oneof='initial_batch_sample_size') - following_batch_sample_percentage = proto.Field( - proto.INT32, number=3, oneof="following_batch_sample_size" - ) + following_batch_sample_percentage = proto.Field(proto.INT32, number=3, oneof='following_batch_sample_size') - sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + sample_strategy = proto.Field(proto.ENUM, number=5, + enum=SampleStrategy, + ) class TrainingConfig(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 9fa17fcb3a..969596f706 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -25,8 +25,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Dataset', + 'ImportDataConfig', + 'ExportDataConfig', + }, ) @@ -94,18 +98,24 @@ class Dataset(proto.Message): metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) + metadata = proto.Field(proto.MESSAGE, number=8, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) etag = proto.Field(proto.STRING, number=6) labels = proto.MapField(proto.STRING, proto.STRING, number=7) - encryption_spec = proto.Field( - proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=11, + message=gca_encryption_spec.EncryptionSpec, ) @@ -141,8 +151,8 @@ class ImportDataConfig(proto.Message): Object `__. """ - gcs_source = proto.Field( - proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, + gcs_source = proto.Field(proto.MESSAGE, number=1, oneof='source', + message=io.GcsSource, ) data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) @@ -175,8 +185,8 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field( - proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', + message=io.GcsDestination, ) annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 1ab94b8c89..73b9b56d5a 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -26,26 +26,26 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateDatasetRequest", - "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", - "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", - "GetAnnotationSpecRequest", - "ListAnnotationsRequest", - "ListAnnotationsResponse", + 'CreateDatasetRequest', + 'CreateDatasetOperationMetadata', + 'GetDatasetRequest', + 'UpdateDatasetRequest', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'DeleteDatasetRequest', + 'ImportDataRequest', + 'ImportDataResponse', + 'ImportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportDataOperationMetadata', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'GetAnnotationSpecRequest', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', }, ) @@ -65,7 +65,9 @@ class CreateDatasetRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) + dataset = proto.Field(proto.MESSAGE, number=2, + message=gca_dataset.Dataset, + ) class CreateDatasetOperationMetadata(proto.Message): @@ -77,8 +79,8 @@ class CreateDatasetOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -95,7 +97,9 @@ class GetDatasetRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class UpdateDatasetRequest(proto.Message): @@ -117,9 +121,13 @@ class UpdateDatasetRequest(proto.Message): - ``labels`` """ - dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) + dataset = proto.Field(proto.MESSAGE, number=1, + message=gca_dataset.Dataset, + ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class ListDatasetsRequest(proto.Message): @@ -171,7 +179,9 @@ class ListDatasetsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -192,8 +202,8 @@ class ListDatasetsResponse(proto.Message): def raw_page(self): return self - datasets = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_dataset.Dataset, + datasets = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_dataset.Dataset, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -229,8 +239,8 @@ class ImportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField( - proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, + import_configs = proto.RepeatedField(proto.MESSAGE, number=2, + message=gca_dataset.ImportDataConfig, ) @@ -249,8 +259,8 @@ class ImportDataOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -268,8 +278,8 @@ class ExportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - export_config = proto.Field( - proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, + export_config = proto.Field(proto.MESSAGE, number=2, + message=gca_dataset.ExportDataConfig, ) @@ -299,8 +309,8 @@ class ExportDataOperationMetadata(proto.Message): the directory. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -337,7 +347,9 @@ class ListDataItemsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -358,8 +370,8 @@ class ListDataItemsResponse(proto.Message): def raw_page(self): return self - data_items = proto.RepeatedField( - proto.MESSAGE, number=1, message=data_item.DataItem, + data_items = proto.RepeatedField(proto.MESSAGE, number=1, + message=data_item.DataItem, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -379,7 +391,9 @@ class GetAnnotationSpecRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class ListAnnotationsRequest(proto.Message): @@ -413,7 +427,9 @@ class ListAnnotationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -434,8 +450,8 @@ class ListAnnotationsResponse(proto.Message): def raw_page(self): return self - annotations = proto.RepeatedField( - proto.MESSAGE, number=1, message=annotation.Annotation, + annotations = proto.RepeatedField(proto.MESSAGE, number=1, + message=annotation.Annotation, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py index b0ec7010a2..aa5c8424aa 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"DeployedModelRef",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'DeployedModelRef', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/encryption_spec.py b/google/cloud/aiplatform_v1beta1/types/encryption_spec.py index 0d41d39a0b..398d935aa4 100644 --- a/google/cloud/aiplatform_v1beta1/types/encryption_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/encryption_spec.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"EncryptionSpec",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'EncryptionSpec', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 40ede068f3..85393de4b8 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -25,7 +25,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"Endpoint", "DeployedModel",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Endpoint', + 'DeployedModel', + }, ) @@ -93,8 +97,8 @@ class Endpoint(proto.Message): description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField( - proto.MESSAGE, number=4, message="DeployedModel", + deployed_models = proto.RepeatedField(proto.MESSAGE, number=4, + message='DeployedModel', ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) @@ -103,12 +107,16 @@ class Endpoint(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) - encryption_spec = proto.Field( - proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=10, + message=gca_encryption_spec.EncryptionSpec, ) @@ -184,17 +192,11 @@ class DeployedModel(proto.Message): option. """ - dedicated_resources = proto.Field( - proto.MESSAGE, - number=7, - oneof="prediction_resources", + dedicated_resources = proto.Field(proto.MESSAGE, number=7, oneof='prediction_resources', message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field( - proto.MESSAGE, - number=8, - oneof="prediction_resources", + automatic_resources = proto.Field(proto.MESSAGE, number=8, oneof='prediction_resources', message=machine_resources.AutomaticResources, ) @@ -204,10 +206,12 @@ class DeployedModel(proto.Message): display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) - explanation_spec = proto.Field( - proto.MESSAGE, number=9, message=explanation.ExplanationSpec, + explanation_spec = proto.Field(proto.MESSAGE, number=9, + message=explanation.ExplanationSpec, ) service_account = proto.Field(proto.STRING, number=11) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index fe7442ab2a..9fa5944c5f 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateEndpointRequest", - "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelRequest", - "DeployModelResponse", - "DeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UndeployModelOperationMetadata", + 'CreateEndpointRequest', + 'CreateEndpointOperationMetadata', + 'GetEndpointRequest', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'UpdateEndpointRequest', + 'DeleteEndpointRequest', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UndeployModelOperationMetadata', }, ) @@ -58,7 +58,9 @@ class CreateEndpointRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) + endpoint = proto.Field(proto.MESSAGE, number=2, + message=gca_endpoint.Endpoint, + ) class CreateEndpointOperationMetadata(proto.Message): @@ -70,8 +72,8 @@ class CreateEndpointOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -141,7 +143,9 @@ class ListEndpointsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListEndpointsResponse(proto.Message): @@ -161,8 +165,8 @@ class ListEndpointsResponse(proto.Message): def raw_page(self): return self - endpoints = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, + endpoints = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_endpoint.Endpoint, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -181,9 +185,13 @@ class UpdateEndpointRequest(proto.Message): `FieldMask `__. """ - endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) + endpoint = proto.Field(proto.MESSAGE, number=1, + message=gca_endpoint.Endpoint, + ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class DeleteEndpointRequest(proto.Message): @@ -236,8 +244,8 @@ class DeployModelRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field( - proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, + deployed_model = proto.Field(proto.MESSAGE, number=2, + message=gca_endpoint.DeployedModel, ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -253,8 +261,8 @@ class DeployModelResponse(proto.Message): the Endpoint. """ - deployed_model = proto.Field( - proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel, + deployed_model = proto.Field(proto.MESSAGE, number=1, + message=gca_endpoint.DeployedModel, ) @@ -267,8 +275,8 @@ class DeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -317,8 +325,8 @@ class UndeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 0d2c3769ff..1e1f279843 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"EnvVar",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'EnvVar', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/event.py b/google/cloud/aiplatform_v1beta1/types/event.py new file mode 100644 index 0000000000..fedaf1e205 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/event.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Event', + }, +) + + +class Event(proto.Message): + r"""An edge describing the relationship between an Artifact and + an Execution in a lineage graph. + + Attributes: + artifact (str): + Required. The relative resource name of the + Artifact in the Event. + execution (str): + Output only. The relative resource name of + the Execution in the Event. + event_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time the Event occurred. + type_ (google.cloud.aiplatform_v1beta1.types.Event.Type): + Required. The type of the Event. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Event.LabelsEntry]): + The labels with user-defined metadata to + annotate Events. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. No more than 64 user labels can be + associated with one Event (System labels are + excluded). + + See https://goo.gl/xmQnxf for more information + and examples of labels. System reserved label + keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + """ + class Type(proto.Enum): + r"""Describes whether an Event's Artifact is the Execution's + input or output. + """ + TYPE_UNSPECIFIED = 0 + INPUT = 1 + OUTPUT = 2 + + artifact = proto.Field(proto.STRING, number=1) + + execution = proto.Field(proto.STRING, number=2) + + event_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) + + type_ = proto.Field(proto.ENUM, number=4, + enum=Type, + ) + + labels = proto.MapField(proto.STRING, proto.STRING, number=5) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/execution.py b/google/cloud/aiplatform_v1beta1/types/execution.py new file mode 100644 index 0000000000..f252dc1def --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/execution.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Execution', + }, +) + + +class Execution(proto.Message): + r"""Instance of a general execution. + + Attributes: + name (str): + Output only. The resource name of the + Execution. + display_name (str): + User provided display name of the Execution. + May be up to 128 Unicode characters. + state (google.cloud.aiplatform_v1beta1.types.Execution.State): + The state of this Execution. This is a + property of the Execution, and does not imply or + capture any ongoing process. This property is + managed by clients (such as AI Platform + Pipelines) and the system does not prescribe or + check the validity of state transitions. + etag (str): + An eTag used to perform consistent read- + odify-write updates. If not set, a blind + "overwrite" update happens. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Execution.LabelsEntry]): + The labels with user-defined metadata to organize your + Executions. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. No more than 64 user labels can be + associated with one Execution (System labels are excluded). + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each Execution: + + - "aiplatform.googleapis.com/schema_title": + + - output only, its value is the title of the Execution + schema provided either by [instance_schema_uri][] or + [instance_schema][]. + + - "aiplatform.googleapis.com/schema_version": + + - output only, its value is the schema version of the + Execution schema provided either by + [instance_schema_uri][] or [instance_schema][]. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Execution + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Execution + was last updated. + schema_title (str): + The title of the schema describing the + metadata. + Schema title and version is expected to be + registered in earlier Create Schema calls. And + both are used together as unique identifiers to + identify schemas within the local metadata + store. + schema_version (str): + The version of the schema in schema_name to use. + + Schema title and version is expected to be registered in + earlier Create Schema calls. And both are used together as + unique identifiers to identify schemas within the local + metadata store. + metadata (google.protobuf.struct_pb2.Struct): + Properties of the Execution. + description (str): + Description of the Execution + """ + class State(proto.Enum): + r"""Describes the state of the Execution.""" + STATE_UNSPECIFIED = 0 + NEW = 1 + RUNNING = 2 + COMPLETE = 3 + FAILED = 4 + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + state = proto.Field(proto.ENUM, number=6, + enum=State, + ) + + etag = proto.Field(proto.STRING, number=9) + + labels = proto.MapField(proto.STRING, proto.STRING, number=10) + + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) + + schema_title = proto.Field(proto.STRING, number=13) + + schema_version = proto.Field(proto.STRING, number=14) + + metadata = proto.Field(proto.MESSAGE, number=15, + message=struct.Struct, + ) + + description = proto.Field(proto.STRING, number=16) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index d9b48b60ab..e7980559cc 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -23,20 +23,20 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "Explanation", - "ModelExplanation", - "Attribution", - "ExplanationSpec", - "ExplanationParameters", - "SampledShapleyAttribution", - "IntegratedGradientsAttribution", - "XraiAttribution", - "SmoothGradConfig", - "FeatureNoiseSigma", - "ExplanationSpecOverride", - "ExplanationMetadataOverride", + 'Explanation', + 'ModelExplanation', + 'Attribution', + 'ExplanationSpec', + 'ExplanationParameters', + 'SampledShapleyAttribution', + 'IntegratedGradientsAttribution', + 'XraiAttribution', + 'SmoothGradConfig', + 'FeatureNoiseSigma', + 'ExplanationSpecOverride', + 'ExplanationMetadataOverride', }, ) @@ -73,7 +73,9 @@ class Explanation(proto.Message): in the same order as they appear in the output_indices. """ - attributions = proto.RepeatedField(proto.MESSAGE, number=1, message="Attribution",) + attributions = proto.RepeatedField(proto.MESSAGE, number=1, + message='Attribution', + ) class ModelExplanation(proto.Message): @@ -110,8 +112,8 @@ class ModelExplanation(proto.Message): is not populated. """ - mean_attributions = proto.RepeatedField( - proto.MESSAGE, number=1, message="Attribution", + mean_attributions = proto.RepeatedField(proto.MESSAGE, number=1, + message='Attribution', ) @@ -235,7 +237,9 @@ class Attribution(proto.Message): instance_output_value = proto.Field(proto.DOUBLE, number=2) - feature_attributions = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + feature_attributions = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) output_index = proto.RepeatedField(proto.INT32, number=4) @@ -258,10 +262,12 @@ class ExplanationSpec(proto.Message): input and output for explanation. """ - parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) + parameters = proto.Field(proto.MESSAGE, number=1, + message='ExplanationParameters', + ) - metadata = proto.Field( - proto.MESSAGE, number=2, message=explanation_metadata.ExplanationMetadata, + metadata = proto.Field(proto.MESSAGE, number=2, + message=explanation_metadata.ExplanationMetadata, ) @@ -318,24 +324,23 @@ class ExplanationParameters(proto.Message): (e,g, multi-class Models that predict multiple classes). """ - sampled_shapley_attribution = proto.Field( - proto.MESSAGE, number=1, oneof="method", message="SampledShapleyAttribution", + sampled_shapley_attribution = proto.Field(proto.MESSAGE, number=1, oneof='method', + message='SampledShapleyAttribution', ) - integrated_gradients_attribution = proto.Field( - proto.MESSAGE, - number=2, - oneof="method", - message="IntegratedGradientsAttribution", + integrated_gradients_attribution = proto.Field(proto.MESSAGE, number=2, oneof='method', + message='IntegratedGradientsAttribution', ) - xrai_attribution = proto.Field( - proto.MESSAGE, number=3, oneof="method", message="XraiAttribution", + xrai_attribution = proto.Field(proto.MESSAGE, number=3, oneof='method', + message='XraiAttribution', ) top_k = proto.Field(proto.INT32, number=4) - output_indices = proto.Field(proto.MESSAGE, number=5, message=struct.ListValue,) + output_indices = proto.Field(proto.MESSAGE, number=5, + message=struct.ListValue, + ) class SampledShapleyAttribution(proto.Message): @@ -382,8 +387,8 @@ class IntegratedGradientsAttribution(proto.Message): step_count = proto.Field(proto.INT32, number=1) - smooth_grad_config = proto.Field( - proto.MESSAGE, number=2, message="SmoothGradConfig", + smooth_grad_config = proto.Field(proto.MESSAGE, number=2, + message='SmoothGradConfig', ) @@ -416,8 +421,8 @@ class XraiAttribution(proto.Message): step_count = proto.Field(proto.INT32, number=1) - smooth_grad_config = proto.Field( - proto.MESSAGE, number=2, message="SmoothGradConfig", + smooth_grad_config = proto.Field(proto.MESSAGE, number=2, + message='SmoothGradConfig', ) @@ -462,13 +467,10 @@ class SmoothGradConfig(proto.Message): Valid range of its value is [1, 50]. Defaults to 3. """ - noise_sigma = proto.Field(proto.FLOAT, number=1, oneof="GradientNoiseSigma") + noise_sigma = proto.Field(proto.FLOAT, number=1, oneof='GradientNoiseSigma') - feature_noise_sigma = proto.Field( - proto.MESSAGE, - number=2, - oneof="GradientNoiseSigma", - message="FeatureNoiseSigma", + feature_noise_sigma = proto.Field(proto.MESSAGE, number=2, oneof='GradientNoiseSigma', + message='FeatureNoiseSigma', ) noisy_sample_count = proto.Field(proto.INT32, number=3) @@ -484,7 +486,6 @@ class FeatureNoiseSigma(proto.Message): Noise sigma per feature. No noise is added to features that are not set. """ - class NoiseSigmaForFeature(proto.Message): r"""Noise sigma for a single feature. @@ -506,8 +507,8 @@ class NoiseSigmaForFeature(proto.Message): sigma = proto.Field(proto.FLOAT, number=2) - noise_sigma = proto.RepeatedField( - proto.MESSAGE, number=1, message=NoiseSigmaForFeature, + noise_sigma = proto.RepeatedField(proto.MESSAGE, number=1, + message=NoiseSigmaForFeature, ) @@ -529,10 +530,12 @@ class ExplanationSpecOverride(proto.Message): specified, no metadata is overridden. """ - parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) + parameters = proto.Field(proto.MESSAGE, number=1, + message='ExplanationParameters', + ) - metadata = proto.Field( - proto.MESSAGE, number=2, message="ExplanationMetadataOverride", + metadata = proto.Field(proto.MESSAGE, number=2, + message='ExplanationMetadataOverride', ) @@ -553,7 +556,6 @@ class ExplanationMetadataOverride(proto.Message): here, the corresponding feature's input metadata is not overridden. """ - class InputMetadataOverride(proto.Message): r"""The [input metadata][google.cloud.aiplatform.v1beta1.ExplanationMetadata.InputMetadata] @@ -570,12 +572,12 @@ class InputMetadataOverride(proto.Message): overridden. """ - input_baselines = proto.RepeatedField( - proto.MESSAGE, number=1, message=struct.Value, + input_baselines = proto.RepeatedField(proto.MESSAGE, number=1, + message=struct.Value, ) - inputs = proto.MapField( - proto.STRING, proto.MESSAGE, number=1, message=InputMetadataOverride, + inputs = proto.MapField(proto.STRING, proto.MESSAGE, number=1, + message=InputMetadataOverride, ) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 69947e9b9e..79cb0925c4 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"ExplanationMetadata",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ExplanationMetadata', + }, ) @@ -70,7 +73,6 @@ class ExplanationMetadata(proto.Message): output URI will point to a location where the user only has a read access. """ - class InputMetadata(proto.Message): r"""Metadata of the input of a feature. @@ -158,7 +160,6 @@ class InputMetadata(proto.Message): featureAttributions][Attribution.feature_attributions], keyed by the group name. """ - class Encoding(proto.Enum): r"""Defines how the feature is encoded to [encoded_tensor][]. Defaults to IDENTITY. @@ -250,7 +251,6 @@ class Visualization(proto.Message): makes it difficult to view the visualization. Defaults to NONE. """ - class Type(proto.Enum): r"""Type of the image visualization. Only applicable to [Integrated Gradients attribution] @@ -287,50 +287,40 @@ class OverlayType(proto.Enum): GRAYSCALE = 3 MASK_BLACK = 4 - type_ = proto.Field( - proto.ENUM, - number=1, - enum="ExplanationMetadata.InputMetadata.Visualization.Type", + type_ = proto.Field(proto.ENUM, number=1, + enum='ExplanationMetadata.InputMetadata.Visualization.Type', ) - polarity = proto.Field( - proto.ENUM, - number=2, - enum="ExplanationMetadata.InputMetadata.Visualization.Polarity", + polarity = proto.Field(proto.ENUM, number=2, + enum='ExplanationMetadata.InputMetadata.Visualization.Polarity', ) - color_map = proto.Field( - proto.ENUM, - number=3, - enum="ExplanationMetadata.InputMetadata.Visualization.ColorMap", + color_map = proto.Field(proto.ENUM, number=3, + enum='ExplanationMetadata.InputMetadata.Visualization.ColorMap', ) clip_percent_upperbound = proto.Field(proto.FLOAT, number=4) clip_percent_lowerbound = proto.Field(proto.FLOAT, number=5) - overlay_type = proto.Field( - proto.ENUM, - number=6, - enum="ExplanationMetadata.InputMetadata.Visualization.OverlayType", + overlay_type = proto.Field(proto.ENUM, number=6, + enum='ExplanationMetadata.InputMetadata.Visualization.OverlayType', ) - input_baselines = proto.RepeatedField( - proto.MESSAGE, number=1, message=struct.Value, + input_baselines = proto.RepeatedField(proto.MESSAGE, number=1, + message=struct.Value, ) input_tensor_name = proto.Field(proto.STRING, number=2) - encoding = proto.Field( - proto.ENUM, number=3, enum="ExplanationMetadata.InputMetadata.Encoding", + encoding = proto.Field(proto.ENUM, number=3, + enum='ExplanationMetadata.InputMetadata.Encoding', ) modality = proto.Field(proto.STRING, number=4) - feature_value_domain = proto.Field( - proto.MESSAGE, - number=5, - message="ExplanationMetadata.InputMetadata.FeatureValueDomain", + feature_value_domain = proto.Field(proto.MESSAGE, number=5, + message='ExplanationMetadata.InputMetadata.FeatureValueDomain', ) indices_tensor_name = proto.Field(proto.STRING, number=6) @@ -341,14 +331,12 @@ class OverlayType(proto.Enum): encoded_tensor_name = proto.Field(proto.STRING, number=9) - encoded_baselines = proto.RepeatedField( - proto.MESSAGE, number=10, message=struct.Value, + encoded_baselines = proto.RepeatedField(proto.MESSAGE, number=10, + message=struct.Value, ) - visualization = proto.Field( - proto.MESSAGE, - number=11, - message="ExplanationMetadata.InputMetadata.Visualization", + visualization = proto.Field(proto.MESSAGE, number=11, + message='ExplanationMetadata.InputMetadata.Visualization', ) group_name = proto.Field(proto.STRING, number=12) @@ -390,22 +378,20 @@ class OutputMetadata(proto.Message): for Tensorflow. """ - index_display_name_mapping = proto.Field( - proto.MESSAGE, number=1, oneof="display_name_mapping", message=struct.Value, + index_display_name_mapping = proto.Field(proto.MESSAGE, number=1, oneof='display_name_mapping', + message=struct.Value, ) - display_name_mapping_key = proto.Field( - proto.STRING, number=2, oneof="display_name_mapping" - ) + display_name_mapping_key = proto.Field(proto.STRING, number=2, oneof='display_name_mapping') output_tensor_name = proto.Field(proto.STRING, number=3) - inputs = proto.MapField( - proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, + inputs = proto.MapField(proto.STRING, proto.MESSAGE, number=1, + message=InputMetadata, ) - outputs = proto.MapField( - proto.STRING, proto.MESSAGE, number=2, message=OutputMetadata, + outputs = proto.MapField(proto.STRING, proto.MESSAGE, number=2, + message=OutputMetadata, ) feature_attributions_schema_uri = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py b/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py new file mode 100644 index 0000000000..4dce0cda0e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'FeatureStatsAnomaly', + }, +) + + +class FeatureStatsAnomaly(proto.Message): + r"""Stats and Anomaly generated at specific timestamp for specific + Feature. The start_time and end_time are used to define the time + range of the dataset that current stats belongs to, e.g. prediction + traffic is bucketed into prediction datasets by time window. If the + Dataset is not defined by time window, start_time = end_time. + Timestamp of the stats and anomalies always refers to end_time. Raw + stats and anomalies are stored in stats_uri or anomaly_uri in the + tensorflow defined protos. Field data_stats contains almost + identical information with the raw stats in AI Platform defined + proto, for UI to display. + + Attributes: + score (float): + Feature importance score, only populated when cross-feature + monitoring is enabled. For now only used to represent + feature attribution score within range [0, 1] for + ``ModelDeploymentMonitoringObjectiveType.FEATURE_ATTRIBUTION_SKEW`` + and + ``ModelDeploymentMonitoringObjectiveType.FEATURE_ATTRIBUTION_DRIFT``. + stats_uri (str): + Path of the stats file for current feature values in Cloud + Storage bucket. Format: + gs:////stats. Example: + gs://monitoring_bucket/feature_name/stats. Stats are stored + as binary format with Protobuf message + `tensorflow.metadata.v0.FeatureNameStatistics `__. + anomaly_uri (str): + Path of the anomaly file for current feature values in Cloud + Storage bucket. Format: + gs:////anomalies. Example: + gs://monitoring_bucket/feature_name/anomalies. Stats are + stored as binary format with Protobuf message Anoamlies are + stored as binary format with Protobuf message + [tensorflow.metadata.v0.AnomalyInfo] + (https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/anomalies.proto). + distribution_deviation (float): + Deviation from the current stats to baseline + stats. 1. For categorical feature, the + distribution distance is calculated by + L-inifinity norm. + 2. For numerical feature, the distribution + distance is calculated by Jensen–Shannon + divergence. + anomaly_detection_threshold (float): + This is the threshold used when detecting anomalies. The + threshold can be changed by user, so this one might be + different from + ``ThresholdConfig.value``. + start_time (google.protobuf.timestamp_pb2.Timestamp): + The start timestamp of window where stats + were generated. + end_time (google.protobuf.timestamp_pb2.Timestamp): + The end timestamp of window where stats were + generated. + """ + + score = proto.Field(proto.DOUBLE, number=1) + + stats_uri = proto.Field(proto.STRING, number=3) + + anomaly_uri = proto.Field(proto.STRING, number=4) + + distribution_deviation = proto.Field(proto.DOUBLE, number=5) + + anomaly_detection_threshold = proto.Field(proto.DOUBLE, number=9) + + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index 55978a409e..fbf5262553 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -27,7 +27,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"HyperparameterTuningJob",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'HyperparameterTuningJob', + }, ) @@ -106,7 +109,9 @@ class HyperparameterTuningJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) + study_spec = proto.Field(proto.MESSAGE, number=4, + message=study.StudySpec, + ) max_trial_count = proto.Field(proto.INT32, number=5) @@ -114,28 +119,42 @@ class HyperparameterTuningJob(proto.Message): max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field( - proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, + trial_job_spec = proto.Field(proto.MESSAGE, number=8, + message=custom_job.CustomJobSpec, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) + trials = proto.RepeatedField(proto.MESSAGE, number=9, + message=study.Trial, + ) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + state = proto.Field(proto.ENUM, number=10, + enum=job_state.JobState, + ) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) - error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=15, + message=status.Status, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=16) - encryption_spec = proto.Field( - proto.MESSAGE, number=17, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=17, + message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index 3a177dcf9b..0d938b4628 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -19,13 +19,13 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "GcsSource", - "GcsDestination", - "BigQuerySource", - "BigQueryDestination", - "ContainerRegistryDestination", + 'GcsSource', + 'GcsDestination', + 'BigQuerySource', + 'BigQueryDestination', + 'ContainerRegistryDestination', }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index 514ca12f7a..bc8b117832 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -18,46 +18,54 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import operation from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateCustomJobRequest", - "GetCustomJobRequest", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "DeleteCustomJobRequest", - "CancelCustomJobRequest", - "CreateDataLabelingJobRequest", - "GetDataLabelingJobRequest", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "DeleteDataLabelingJobRequest", - "CancelDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "GetHyperparameterTuningJobRequest", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "DeleteHyperparameterTuningJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "GetBatchPredictionJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "DeleteBatchPredictionJobRequest", - "CancelBatchPredictionJobRequest", + 'CreateCustomJobRequest', + 'GetCustomJobRequest', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'DeleteCustomJobRequest', + 'CancelCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'GetDataLabelingJobRequest', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'DeleteDataLabelingJobRequest', + 'CancelDataLabelingJobRequest', + 'CreateHyperparameterTuningJobRequest', + 'GetHyperparameterTuningJobRequest', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'DeleteHyperparameterTuningJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CreateBatchPredictionJobRequest', + 'GetBatchPredictionJobRequest', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'DeleteBatchPredictionJobRequest', + 'CancelBatchPredictionJobRequest', + 'CreateModelDeploymentMonitoringJobRequest', + 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', + 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', + 'GetModelDeploymentMonitoringJobRequest', + 'ListModelDeploymentMonitoringJobsRequest', + 'ListModelDeploymentMonitoringJobsResponse', + 'UpdateModelDeploymentMonitoringJobRequest', + 'DeleteModelDeploymentMonitoringJobRequest', + 'PauseModelDeploymentMonitoringJobRequest', + 'ResumeModelDeploymentMonitoringJobRequest', + 'UpdateModelDeploymentMonitoringJobOperationMetadata', }, ) @@ -77,7 +85,9 @@ class CreateCustomJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) + custom_job = proto.Field(proto.MESSAGE, number=2, + message=gca_custom_job.CustomJob, + ) class GetCustomJobRequest(proto.Message): @@ -140,7 +150,9 @@ class ListCustomJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListCustomJobsResponse(proto.Message): @@ -160,8 +172,8 @@ class ListCustomJobsResponse(proto.Message): def raw_page(self): return self - custom_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, + custom_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_custom_job.CustomJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -208,8 +220,8 @@ class CreateDataLabelingJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field( - proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, + data_labeling_job = proto.Field(proto.MESSAGE, number=2, + message=gca_data_labeling_job.DataLabelingJob, ) @@ -274,7 +286,9 @@ class ListDataLabelingJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) order_by = proto.Field(proto.STRING, number=6) @@ -295,8 +309,8 @@ class ListDataLabelingJobsResponse(proto.Message): def raw_page(self): return self - data_labeling_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, + data_labeling_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_data_labeling_job.DataLabelingJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -345,9 +359,7 @@ class CreateHyperparameterTuningJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field( - proto.MESSAGE, - number=2, + hyperparameter_tuning_job = proto.Field(proto.MESSAGE, number=2, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -413,7 +425,9 @@ class ListHyperparameterTuningJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListHyperparameterTuningJobsResponse(proto.Message): @@ -435,9 +449,7 @@ class ListHyperparameterTuningJobsResponse(proto.Message): def raw_page(self): return self - hyperparameter_tuning_jobs = proto.RepeatedField( - proto.MESSAGE, - number=1, + hyperparameter_tuning_jobs = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -487,8 +499,8 @@ class CreateBatchPredictionJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field( - proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_job = proto.Field(proto.MESSAGE, number=2, + message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -555,7 +567,9 @@ class ListBatchPredictionJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListBatchPredictionJobsResponse(proto.Message): @@ -576,8 +590,8 @@ class ListBatchPredictionJobsResponse(proto.Message): def raw_page(self): return self - batch_prediction_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_batch_prediction_job.BatchPredictionJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -611,4 +625,278 @@ class CancelBatchPredictionJobRequest(proto.Message): name = proto.Field(proto.STRING, number=1) +class CreateModelDeploymentMonitoringJobRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.CreateModelDeploymentMonitoringJob][]. + + Attributes: + parent (str): + Required. The parent of the ModelDeploymentMonitoringJob. + Format: ``projects/{project}/locations/{location}`` + model_deployment_monitoring_job (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob): + Required. The ModelDeploymentMonitoringJob to + create + """ + + parent = proto.Field(proto.STRING, number=1) + + model_deployment_monitoring_job = proto.Field(proto.MESSAGE, number=2, + message=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + ) + + +class SearchModelDeploymentMonitoringStatsAnomaliesRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + + Attributes: + model_deployment_monitoring_job (str): + Required. ModelDeploymentMonitoring Job resource name. + Format: + \`projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job} + deployed_model_id (str): + Required. The DeployedModel ID of the + [google.cloud.aiplatform.master.ModelDeploymentMonitoringObjectiveConfig.deployed_model_id]. + feature_display_name (str): + The feature display name. If specified, only return the + stats belonging to this feature. Format: + ``ModelMonitoringStatsAnomalies.FeatureHistoricStatsAnomalies.feature_display_name``, + example: "user_destination". + objectives (Sequence[google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest.StatsAnomaliesObjective]): + Required. Objectives of the stats to + retrieve. + page_size (int): + The standard list page size. + page_token (str): + A page token received from a previous + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][] + call. + start_time (google.protobuf.timestamp_pb2.Timestamp): + The earliest timestamp of stats being + generated. If not set, indicates fetching stats + till the earliest possible one. + end_time (google.protobuf.timestamp_pb2.Timestamp): + The latest timestamp of stats being + generated. If not set, indicates feching stats + till the latest possible one. + """ + class StatsAnomaliesObjective(proto.Message): + r"""Stats requested for specific objective. + + Attributes: + type_ (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringObjectiveType): + + top_feature_count (int): + If set, all attribution scores between + ``SearchModelDeploymentMonitoringStatsAnomaliesRequest.start_time`` + and + ``SearchModelDeploymentMonitoringStatsAnomaliesRequest.end_time`` + are fetched, and page token doesn't take affect in this + case. Only used to retrieve attribution score for the top + Features which has the highest attribution score in the + latest monitoring run. + """ + + type_ = proto.Field(proto.ENUM, number=1, + enum=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringObjectiveType, + ) + + top_feature_count = proto.Field(proto.INT32, number=4) + + model_deployment_monitoring_job = proto.Field(proto.STRING, number=1) + + deployed_model_id = proto.Field(proto.STRING, number=2) + + feature_display_name = proto.Field(proto.STRING, number=3) + + objectives = proto.RepeatedField(proto.MESSAGE, number=4, + message=StatsAnomaliesObjective, + ) + + page_size = proto.Field(proto.INT32, number=5) + + page_token = proto.Field(proto.STRING, number=6) + + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) + + +class SearchModelDeploymentMonitoringStatsAnomaliesResponse(proto.Message): + r"""Response message for + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + + Attributes: + monitoring_stats (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringStatsAnomalies]): + Stats retrieved for requested objectives. There are at most + 1000 + [ModelMonitoringStatsAnomalies.feature_stats.prediction_stats][] + in the response. + next_page_token (str): + The page token that can be used by the next + [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][] + call. + """ + + @property + def raw_page(self): + return self + + monitoring_stats = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class GetModelDeploymentMonitoringJobRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.GetModelDeploymentMonitoringJob][]. + + Attributes: + name (str): + Required. The resource name of the + ModelDeploymentMonitoringJob. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListModelDeploymentMonitoringJobsRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + + Attributes: + parent (str): + Required. The parent of the ModelDeploymentMonitoringJob. + Format: ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + + +class ListModelDeploymentMonitoringJobsResponse(proto.Message): + r"""Response message for + [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + + Attributes: + model_deployment_monitoring_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob]): + A list of ModelDeploymentMonitoringJobs that + matches the specified filter in the request. + next_page_token (str): + The standard List next-page token. + """ + + @property + def raw_page(self): + return self + + model_deployment_monitoring_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateModelDeploymentMonitoringJobRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + + Attributes: + model_deployment_monitoring_job (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob): + Required. The model monitoring configuration + which replaces the resource on the server. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the + resource. + """ + + model_deployment_monitoring_job = proto.Field(proto.MESSAGE, number=1, + message=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + +class DeleteModelDeploymentMonitoringJobRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.DeleteModelDeploymentMonitoringJob][]. + + Attributes: + name (str): + Required. The resource name of the model monitoring job to + delete. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class PauseModelDeploymentMonitoringJobRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.PauseModelDeploymentMonitoringJob][]. + + Attributes: + name (str): + Required. The resource name of the + ModelDeploymentMonitoringJob to pause. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ResumeModelDeploymentMonitoringJobRequest(proto.Message): + r"""Request message for + [ModelDeploymentMonitoringJobService.ResumeModelDeploymentMonitoringJob][]. + + Attributes: + name (str): + Required. The resource name of the + ModelDeploymentMonitoringJob to resume. Format: + ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class UpdateModelDeploymentMonitoringJobOperationMetadata(proto.Message): + r"""Runtime operation information for + [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/job_state.py b/google/cloud/aiplatform_v1beta1/types/job_state.py index b77947cc9a..6d199390db 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_state.py +++ b/google/cloud/aiplatform_v1beta1/types/job_state.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"JobState",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'JobState', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py b/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py new file mode 100644 index 0000000000..ba291eb8f6 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import event +from google.cloud.aiplatform_v1beta1.types import execution + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'LineageSubgraph', + }, +) + + +class LineageSubgraph(proto.Message): + r"""A subgraph of the overall lineage graph. Event edges connect + Artifact and Execution nodes. + + Attributes: + artifacts (Sequence[google.cloud.aiplatform_v1beta1.types.Artifact]): + The Artifact nodes in the subgraph. + executions (Sequence[google.cloud.aiplatform_v1beta1.types.Execution]): + The Execution nodes in the subgraph. + events (Sequence[google.cloud.aiplatform_v1beta1.types.Event]): + The Event edges between Artifacts and + Executions in the subgraph. + """ + + artifacts = proto.RepeatedField(proto.MESSAGE, number=1, + message=artifact.Artifact, + ) + + executions = proto.RepeatedField(proto.MESSAGE, number=2, + message=execution.Execution, + ) + + events = proto.RepeatedField(proto.MESSAGE, number=3, + message=event.Event, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index c791354c58..48b2ad18c4 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -18,21 +18,19 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - accelerator_type as gca_accelerator_type, -) +from google.cloud.aiplatform_v1beta1.types import accelerator_type as gca_accelerator_type __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "MachineSpec", - "DedicatedResources", - "AutomaticResources", - "BatchDedicatedResources", - "ResourcesConsumed", - "DiskSpec", - "AutoscalingMetricSpec", + 'MachineSpec', + 'DedicatedResources', + 'AutomaticResources', + 'BatchDedicatedResources', + 'ResourcesConsumed', + 'DiskSpec', + 'AutoscalingMetricSpec', }, ) @@ -67,8 +65,8 @@ class MachineSpec(proto.Message): machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field( - proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, + accelerator_type = proto.Field(proto.ENUM, number=2, + enum=gca_accelerator_type.AcceleratorType, ) accelerator_count = proto.Field(proto.INT32, number=3) @@ -135,14 +133,16 @@ class DedicatedResources(proto.Message): to ``80``. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + machine_spec = proto.Field(proto.MESSAGE, number=1, + message='MachineSpec', + ) min_replica_count = proto.Field(proto.INT32, number=2) max_replica_count = proto.Field(proto.INT32, number=3) - autoscaling_metric_specs = proto.RepeatedField( - proto.MESSAGE, number=4, message="AutoscalingMetricSpec", + autoscaling_metric_specs = proto.RepeatedField(proto.MESSAGE, number=4, + message='AutoscalingMetricSpec', ) @@ -203,7 +203,9 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + machine_spec = proto.Field(proto.MESSAGE, number=1, + message='MachineSpec', + ) starting_replica_count = proto.Field(proto.INT32, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py index 7a467d5069..da5c4d38ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py +++ b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"ManualBatchTuningParameters",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ManualBatchTuningParameters', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_schema.py b/google/cloud/aiplatform_v1beta1/types/metadata_schema.py new file mode 100644 index 0000000000..7c690a1b94 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/metadata_schema.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'MetadataSchema', + }, +) + + +class MetadataSchema(proto.Message): + r"""Instance of a general MetadataSchema. + + Attributes: + name (str): + Output only. The resource name of the + MetadataSchema. + schema_version (str): + The version of the MetadataSchema. The version's format must + match the following regular expression: + ``^[0-9]+[.][0-9]+[.][0-9]+$``, which would allow to + order/compare different versions.Example: 1.0.0, 1.0.1, etc. + schema (str): + Required. The raw YAML string representation of the + MetadataSchema. The combination of [MetadataSchema.version] + and the schema name given by ``title`` in + [MetadataSchema.schema] must be unique within a + MetadataStore. + + The schema is defined as an OpenAPI 3.0.2 `MetadataSchema + Object `__ + schema_type (google.cloud.aiplatform_v1beta1.types.MetadataSchema.MetadataSchemaType): + The type of the MetadataSchema. This is a + property that identifies which metadata types + will use the MetadataSchema. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + MetadataSchema was created. + description (str): + Description of the Metadata Schema + """ + class MetadataSchemaType(proto.Enum): + r"""Describes the type of the MetadataSchema.""" + METADATA_SCHEMA_TYPE_UNSPECIFIED = 0 + ARTIFACT_TYPE = 1 + EXECUTION_TYPE = 2 + CONTEXT_TYPE = 3 + + name = proto.Field(proto.STRING, number=1) + + schema_version = proto.Field(proto.STRING, number=2) + + schema = proto.Field(proto.STRING, number=3) + + schema_type = proto.Field(proto.ENUM, number=4, + enum=MetadataSchemaType, + ) + + create_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) + + description = proto.Field(proto.STRING, number=6) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_service.py b/google/cloud/aiplatform_v1beta1/types/metadata_service.py new file mode 100644 index 0000000000..3777316237 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/metadata_service.py @@ -0,0 +1,900 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import event +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store +from google.cloud.aiplatform_v1beta1.types import operation +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'CreateMetadataStoreRequest', + 'CreateMetadataStoreOperationMetadata', + 'GetMetadataStoreRequest', + 'ListMetadataStoresRequest', + 'ListMetadataStoresResponse', + 'DeleteMetadataStoreRequest', + 'DeleteMetadataStoreOperationMetadata', + 'CreateArtifactRequest', + 'GetArtifactRequest', + 'ListArtifactsRequest', + 'ListArtifactsResponse', + 'UpdateArtifactRequest', + 'CreateContextRequest', + 'GetContextRequest', + 'ListContextsRequest', + 'ListContextsResponse', + 'UpdateContextRequest', + 'DeleteContextRequest', + 'AddContextArtifactsAndExecutionsRequest', + 'AddContextArtifactsAndExecutionsResponse', + 'AddContextChildrenRequest', + 'AddContextChildrenResponse', + 'QueryContextLineageSubgraphRequest', + 'CreateExecutionRequest', + 'GetExecutionRequest', + 'ListExecutionsRequest', + 'ListExecutionsResponse', + 'UpdateExecutionRequest', + 'AddExecutionEventsRequest', + 'AddExecutionEventsResponse', + 'QueryExecutionInputsAndOutputsRequest', + 'CreateMetadataSchemaRequest', + 'GetMetadataSchemaRequest', + 'ListMetadataSchemasRequest', + 'ListMetadataSchemasResponse', + }, +) + + +class CreateMetadataStoreRequest(proto.Message): + r"""Request message for + ``MetadataService.CreateMetadataStore``. + + Attributes: + parent (str): + Required. The resource name of the Location + where the MetadataStore should be created. + Format: projects/{project}/locations/{location}/ + metadata_store (google.cloud.aiplatform_v1beta1.types.MetadataStore): + Required. The MetadataStore to create. + metadata_store_id (str): + The {metadatastore} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + If not provided, the MetadataStore's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be unique + across all MetadataStores in the parent Location. (Otherwise + the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the preexisting + MetadataStore.) + """ + + parent = proto.Field(proto.STRING, number=1) + + metadata_store = proto.Field(proto.MESSAGE, number=2, + message=gca_metadata_store.MetadataStore, + ) + + metadata_store_id = proto.Field(proto.STRING, number=3) + + +class CreateMetadataStoreOperationMetadata(proto.Message): + r"""Details of operations that perform + ``MetadataService.CreateMetadataStore``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for creating a + MetadataStore. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class GetMetadataStoreRequest(proto.Message): + r"""Request message for + ``MetadataService.GetMetadataStore``. + + Attributes: + name (str): + Required. The resource name of the + MetadataStore to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListMetadataStoresRequest(proto.Message): + r"""Request message for + ``MetadataService.ListMetadataStores``. + + Attributes: + parent (str): + Required. The Location whose MetadataStores + should be listed. Format: + projects/{project}/locations/{location} + page_size (int): + The maximum number of Metadata Stores to + return. The service may return fewer. + Must be in range 1-1000, inclusive. Defaults to + 100. + page_token (str): + A page token, received from a previous + ``MetadataService.ListMetadataStores`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other provided parameters must match + the call that provided the page token. (Otherwise the + request will fail with INVALID_ARGUMENT error.) + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + +class ListMetadataStoresResponse(proto.Message): + r"""Response message for + ``MetadataService.ListMetadataStores``. + + Attributes: + metadata_stores (Sequence[google.cloud.aiplatform_v1beta1.types.MetadataStore]): + The MetadataStores found for the Location. + next_page_token (str): + A token, which can be sent as + [MetadataService.ListMetadataStores.page_token][] to + retrieve the next page. If this field is not populated, + there are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + metadata_stores = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_metadata_store.MetadataStore, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteMetadataStoreRequest(proto.Message): + r"""Request message for + ``MetadataService.DeleteMetadataStore``. + + Attributes: + name (str): + Required. The resource name of the + MetadataStore to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + force (bool): + If set to true, any child resources of this MetadataStore + will be deleted. (Otherwise, the request will fail with a + FAILED_PRECONDITION error if the MetadataStore has any child + resources.) + """ + + name = proto.Field(proto.STRING, number=1) + + force = proto.Field(proto.BOOL, number=2) + + +class DeleteMetadataStoreOperationMetadata(proto.Message): + r"""Details of operations that perform + ``MetadataService.DeleteMetadataStore``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for deleting a + MetadataStore. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class CreateArtifactRequest(proto.Message): + r"""Request message for + ``MetadataService.CreateArtifact``. + + Attributes: + parent (str): + Required. The resource name of the + MetadataStore where the Artifact should be + created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + artifact (google.cloud.aiplatform_v1beta1.types.Artifact): + Required. The Artifact to create. + artifact_id (str): + The {artifact} portion of the resource name with the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + If not provided, the Artifact's ID will be a UUID generated + by the service. Must be 4-128 characters in length. Valid + characters are /[a-z][0-9]-/. Must be unique across all + Artifacts in the parent MetadataStore. (Otherwise the + request will fail with ALREADY_EXISTS, or PERMISSION_DENIED + if the caller can't view the preexisting Artifact.) + """ + + parent = proto.Field(proto.STRING, number=1) + + artifact = proto.Field(proto.MESSAGE, number=2, + message=gca_artifact.Artifact, + ) + + artifact_id = proto.Field(proto.STRING, number=3) + + +class GetArtifactRequest(proto.Message): + r"""Request message for + ``MetadataService.GetArtifact``. + + Attributes: + name (str): + Required. The resource name of the Artifact + to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListArtifactsRequest(proto.Message): + r"""Request message for + ``MetadataService.ListArtifacts``. + + Attributes: + parent (str): + Required. The MetadataStore whose Artifacts + should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + page_size (int): + The maximum number of Artifacts to return. + The service may return fewer. Must be in range + 1-1000, inclusive. Defaults to 100. + page_token (str): + A page token, received from a previous + ``MetadataService.ListArtifacts`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other provided parameters must match + the call that provided the page token. (Otherwise the + request will fail with INVALID_ARGUMENT error.) + filter (str): + A query to filter available Artifacts for + matching results. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + filter = proto.Field(proto.STRING, number=4) + + +class ListArtifactsResponse(proto.Message): + r"""Response message for + ``MetadataService.ListArtifacts``. + + Attributes: + artifacts (Sequence[google.cloud.aiplatform_v1beta1.types.Artifact]): + The Artifacts retrieved from the + MetadataStore. + next_page_token (str): + A token, which can be sent as + [MetadataService.ListArtifacts.page_token][] to retrieve the + next page. If this field is not populated, there are no + subsequent pages. + """ + + @property + def raw_page(self): + return self + + artifacts = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_artifact.Artifact, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateArtifactRequest(proto.Message): + r"""Request message for + ``MetadataService.UpdateArtifact``. + + Attributes: + artifact (google.cloud.aiplatform_v1beta1.types.Artifact): + Required. The Artifact containing updates. The Artifact's + ``Artifact.name`` + field is used to identify the Artifact to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. A FieldMask indicating which fields + should be updated. + allow_missing (bool): + If set to true, and the + ``Artifact`` is not + found, a new + ``Artifact`` will be + created. In this situation, ``update_mask`` is ignored. + """ + + artifact = proto.Field(proto.MESSAGE, number=1, + message=gca_artifact.Artifact, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + allow_missing = proto.Field(proto.BOOL, number=3) + + +class CreateContextRequest(proto.Message): + r"""Request message for + ``MetadataService.CreateContext``. + + Attributes: + parent (str): + Required. The resource name of the + MetadataStore where the Context should be + created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + context (google.cloud.aiplatform_v1beta1.types.Context): + Required. The Context to create. + context_id (str): + The {context} portion of the resource name with the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + If not provided, the Context's ID will be a UUID generated + by the service. Must be 4-128 characters in length. Valid + characters are /[a-z][0-9]-/. Must be unique across all + Contexts in the parent MetadataStore. (Otherwise the request + will fail with ALREADY_EXISTS, or PERMISSION_DENIED if the + caller can't view the preexisting Context.) + """ + + parent = proto.Field(proto.STRING, number=1) + + context = proto.Field(proto.MESSAGE, number=2, + message=gca_context.Context, + ) + + context_id = proto.Field(proto.STRING, number=3) + + +class GetContextRequest(proto.Message): + r"""Request message for + ``MetadataService.GetContext``. + + Attributes: + name (str): + Required. The resource name of the Context to + retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListContextsRequest(proto.Message): + r"""Request message for + ``MetadataService.ListContexts`` + + Attributes: + parent (str): + Required. The MetadataStore whose Contexts + should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + page_size (int): + The maximum number of Contexts to return. The + service may return fewer. Must be in range + 1-1000, inclusive. Defaults to 100. + page_token (str): + A page token, received from a previous + ``MetadataService.ListContexts`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other provided parameters must match + the call that provided the page token. (Otherwise the + request will fail with INVALID_ARGUMENT error.) + filter (str): + A query to filter available Contexts for + matching results. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + filter = proto.Field(proto.STRING, number=4) + + +class ListContextsResponse(proto.Message): + r"""Response message for + ``MetadataService.ListContexts``. + + Attributes: + contexts (Sequence[google.cloud.aiplatform_v1beta1.types.Context]): + The Contexts retrieved from the + MetadataStore. + next_page_token (str): + A token, which can be sent as + [MetadataService.ListContexts.page_token][] to retrieve the + next page. If this field is not populated, there are no + subsequent pages. + """ + + @property + def raw_page(self): + return self + + contexts = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_context.Context, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateContextRequest(proto.Message): + r"""Request message for + ``MetadataService.UpdateContext``. + + Attributes: + context (google.cloud.aiplatform_v1beta1.types.Context): + Required. The Context containing updates. The Context's + ``Context.name`` + field is used to identify the Context to be updated. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. A FieldMask indicating which fields + should be updated. + allow_missing (bool): + If set to true, and the + ``Context`` is not + found, a new + ``Context`` will be + created. In this situation, ``update_mask`` is ignored. + """ + + context = proto.Field(proto.MESSAGE, number=1, + message=gca_context.Context, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + allow_missing = proto.Field(proto.BOOL, number=3) + + +class DeleteContextRequest(proto.Message): + r"""Request message for + ``MetadataService.DeleteContext``. + + Attributes: + name (str): + Required. The resource name of the Context to + retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + force (bool): + If set to true, any child resources of this Context will be + deleted. (Otherwise, the request will fail with a + FAILED_PRECONDITION error if the Context has any child + resources, such as another Context, Artifact, or Execution). + """ + + name = proto.Field(proto.STRING, number=1) + + force = proto.Field(proto.BOOL, number=2) + + +class AddContextArtifactsAndExecutionsRequest(proto.Message): + r"""Request message for + ``MetadataService.AddContextArtifactsAndExecutions``. + + Attributes: + context (str): + Required. The resource name of the Context + that the Artifacts and Executions belong to. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + artifacts (Sequence[str]): + The resource names of the Artifacts to + attribute to the Context. + executions (Sequence[str]): + The resource names of the Executions to + associate with the Context. + """ + + context = proto.Field(proto.STRING, number=1) + + artifacts = proto.RepeatedField(proto.STRING, number=2) + + executions = proto.RepeatedField(proto.STRING, number=3) + + +class AddContextArtifactsAndExecutionsResponse(proto.Message): + r"""Response message for + ``MetadataService.AddContextArtifactsAndExecutions``. + """ + + +class AddContextChildrenRequest(proto.Message): + r"""Request message for + ``MetadataService.AddContextChildren``. + + Attributes: + context (str): + Required. The resource name of the parent + Context. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + child_contexts (Sequence[str]): + The resource names of the child Contexts. + """ + + context = proto.Field(proto.STRING, number=1) + + child_contexts = proto.RepeatedField(proto.STRING, number=2) + + +class AddContextChildrenResponse(proto.Message): + r"""Response message for + ``MetadataService.AddContextChildren``. + """ + + +class QueryContextLineageSubgraphRequest(proto.Message): + r"""Request message for + ``MetadataService.QueryContextLineageSubgraph``. + + Attributes: + context (str): + Required. The resource name of the Context whose Artifacts + and Executions should be retrieved as a LineageSubgraph. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} + + The request may error with FAILED_PRECONDITION if the number + of Artifacts, the number of Executions, or the number of + Events that would be returned for the Context exceeds 1000. + """ + + context = proto.Field(proto.STRING, number=1) + + +class CreateExecutionRequest(proto.Message): + r"""Request message for + ``MetadataService.CreateExecution``. + + Attributes: + parent (str): + Required. The resource name of the + MetadataStore where the Execution should be + created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + execution (google.cloud.aiplatform_v1beta1.types.Execution): + Required. The Execution to create. + execution_id (str): + The {execution} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + If not provided, the Execution's ID will be a UUID generated + by the service. Must be 4-128 characters in length. Valid + characters are /[a-z][0-9]-/. Must be unique across all + Executions in the parent MetadataStore. (Otherwise the + request will fail with ALREADY_EXISTS, or PERMISSION_DENIED + if the caller can't view the preexisting Execution.) + """ + + parent = proto.Field(proto.STRING, number=1) + + execution = proto.Field(proto.MESSAGE, number=2, + message=gca_execution.Execution, + ) + + execution_id = proto.Field(proto.STRING, number=3) + + +class GetExecutionRequest(proto.Message): + r"""Request message for + ``MetadataService.GetExecution``. + + Attributes: + name (str): + Required. The resource name of the Execution + to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListExecutionsRequest(proto.Message): + r"""Request message for + ``MetadataService.ListExecutions``. + + Attributes: + parent (str): + Required. The MetadataStore whose Executions + should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + page_size (int): + The maximum number of Executions to return. + The service may return fewer. Must be in range + 1-1000, inclusive. Defaults to 100. + page_token (str): + A page token, received from a previous + ``MetadataService.ListExecutions`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other provided parameters must match + the call that provided the page token. (Otherwise the + request will fail with INVALID_ARGUMENT error.) + filter (str): + A query to filter available Executions for matching results. + Current implementation supports filtering on fields: + + 1) display_name e.g display_name = "test_name" + 2) state e.g. state = RUNNING + 3) create_time and update_time e.g create_time > + "2020-12-17T13:25:12-08:00" + 4) metadata e.g metadata.flag.number_value > 1 + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + filter = proto.Field(proto.STRING, number=4) + + +class ListExecutionsResponse(proto.Message): + r"""Response message for + ``MetadataService.ListExecutions``. + + Attributes: + executions (Sequence[google.cloud.aiplatform_v1beta1.types.Execution]): + The Executions retrieved from the + MetadataStore. + next_page_token (str): + A token, which can be sent as + [MetadataService.ListExecutions.page_token][] to retrieve + the next page. If this field is not populated, there are no + subsequent pages. + """ + + @property + def raw_page(self): + return self + + executions = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_execution.Execution, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateExecutionRequest(proto.Message): + r"""Request message for + ``MetadataService.UpdateExecution``. + + Attributes: + execution (google.cloud.aiplatform_v1beta1.types.Execution): + Required. The Execution containing updates. The Execution's + ``Execution.name`` + field is used to identify the Execution to be updated. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. A FieldMask indicating which fields + should be updated. + allow_missing (bool): + If set to true, and the + ``Execution`` is + not found, a new + ``Execution`` will + be created. In this situation, ``update_mask`` is ignored. + """ + + execution = proto.Field(proto.MESSAGE, number=1, + message=gca_execution.Execution, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + allow_missing = proto.Field(proto.BOOL, number=3) + + +class AddExecutionEventsRequest(proto.Message): + r"""Request message for + ``MetadataService.AddExecutionEvents``. + + Attributes: + execution (str): + Required. The resource name of the Execution + that the Events connect Artifacts with. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + events (Sequence[google.cloud.aiplatform_v1beta1.types.Event]): + The Events to create and add. + """ + + execution = proto.Field(proto.STRING, number=1) + + events = proto.RepeatedField(proto.MESSAGE, number=2, + message=event.Event, + ) + + +class AddExecutionEventsResponse(proto.Message): + r"""Response message for + ``MetadataService.AddExecutionEvents``. + """ + + +class QueryExecutionInputsAndOutputsRequest(proto.Message): + r"""Request message for + ``MetadataService.QueryExecutionInputsAndOutputs``. + + Attributes: + execution (str): + Required. The resource name of the Execution + whose input and output Artifacts should be + retrieved as a LineageSubgraph. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + """ + + execution = proto.Field(proto.STRING, number=1) + + +class CreateMetadataSchemaRequest(proto.Message): + r"""Request message for + ``MetadataService.CreateMetadataSchema``. + + Attributes: + parent (str): + Required. The resource name of the + MetadataStore where the MetadataSchema should be + created. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + metadata_schema (google.cloud.aiplatform_v1beta1.types.MetadataSchema): + Required. The MetadataSchema to create. + metadata_schema_id (str): + The {metadata_schema} portion of the resource name with the + format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/metadataSchemas/{metadataschema} + If not provided, the MetadataStore's ID will be a UUID + generated by the service. Must be 4-128 characters in + length. Valid characters are /[a-z][0-9]-/. Must be unique + across all MetadataSchemas in the parent Location. + (Otherwise the request will fail with ALREADY_EXISTS, or + PERMISSION_DENIED if the caller can't view the preexisting + MetadataSchema.) + """ + + parent = proto.Field(proto.STRING, number=1) + + metadata_schema = proto.Field(proto.MESSAGE, number=2, + message=gca_metadata_schema.MetadataSchema, + ) + + metadata_schema_id = proto.Field(proto.STRING, number=3) + + +class GetMetadataSchemaRequest(proto.Message): + r"""Request message for + ``MetadataService.GetMetadataSchema``. + + Attributes: + name (str): + Required. The resource name of the + MetadataSchema to retrieve. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/metadataSchemas/{metadataschema} + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListMetadataSchemasRequest(proto.Message): + r"""Request message for + ``MetadataService.ListMetadataSchemas``. + + Attributes: + parent (str): + Required. The MetadataStore whose + MetadataSchemas should be listed. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + page_size (int): + The maximum number of MetadataSchemas to + return. The service may return fewer. + Must be in range 1-1000, inclusive. Defaults to + 100. + page_token (str): + A page token, received from a previous + ``MetadataService.ListMetadataSchemas`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other provided parameters must match + the call that provided the page token. (Otherwise the + request will fail with INVALID_ARGUMENT error.) + filter (str): + A query to filter available MetadataSchemas + for matching results. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + filter = proto.Field(proto.STRING, number=4) + + +class ListMetadataSchemasResponse(proto.Message): + r"""Response message for + ``MetadataService.ListMetadataSchemas``. + + Attributes: + metadata_schemas (Sequence[google.cloud.aiplatform_v1beta1.types.MetadataSchema]): + The MetadataSchemas found for the + MetadataStore. + next_page_token (str): + A token, which can be sent as + [MetadataService.ListMetadataSchemas.page_token][] to + retrieve the next page. If this field is not populated, + there are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + metadata_schemas = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_metadata_schema.MetadataSchema, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_store.py b/google/cloud/aiplatform_v1beta1/types/metadata_store.py new file mode 100644 index 0000000000..da4704e31d --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/metadata_store.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'MetadataStore', + }, +) + + +class MetadataStore(proto.Message): + r"""Instance of a metadata store. Contains a set of metadata that + can be queried. + + Attributes: + name (str): + Output only. The resource name of the + MetadataStore instance. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + MetadataStore was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + MetadataStore was last updated. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for an + Metadata Store. If set, this Metadata Store and + all sub-resources of this Metadata Store will be + secured by this key. + """ + + name = proto.Field(proto.STRING, number=1) + + create_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + encryption_spec = proto.Field(proto.MESSAGE, number=5, + message=gca_encryption_spec.EncryptionSpec, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py index 9a695ea349..07f9565af6 100644 --- a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"MigratableResource",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'MigratableResource', + }, ) @@ -52,7 +55,6 @@ class MigratableResource(proto.Message): Output only. Timestamp when this MigratableResource was last updated. """ - class MlEngineModelVersion(proto.Message): r"""Represents one model Version in ml.googleapis.com. @@ -121,7 +123,6 @@ class DataLabelingDataset(proto.Message): datalabeling.googleapis.com belongs to the data labeling Dataset. """ - class DataLabelingAnnotatedDataset(proto.Message): r"""Represents one AnnotatedDataset in datalabeling.googleapis.com. @@ -144,34 +145,32 @@ class DataLabelingAnnotatedDataset(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=4) - data_labeling_annotated_datasets = proto.RepeatedField( - proto.MESSAGE, - number=3, - message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", + data_labeling_annotated_datasets = proto.RepeatedField(proto.MESSAGE, number=3, + message='MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset', ) - ml_engine_model_version = proto.Field( - proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, + ml_engine_model_version = proto.Field(proto.MESSAGE, number=1, oneof='resource', + message=MlEngineModelVersion, ) - automl_model = proto.Field( - proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, + automl_model = proto.Field(proto.MESSAGE, number=2, oneof='resource', + message=AutomlModel, ) - automl_dataset = proto.Field( - proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, + automl_dataset = proto.Field(proto.MESSAGE, number=3, oneof='resource', + message=AutomlDataset, ) - data_labeling_dataset = proto.Field( - proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, + data_labeling_dataset = proto.Field(proto.MESSAGE, number=4, oneof='resource', + message=DataLabelingDataset, ) - last_migrate_time = proto.Field( - proto.MESSAGE, number=5, message=timestamp.Timestamp, + last_migrate_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, ) - last_update_time = proto.Field( - proto.MESSAGE, number=6, message=timestamp.Timestamp, + last_update_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, ) diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py index de4c9466f6..ec23daf2ff 100644 --- a/google/cloud/aiplatform_v1beta1/types/migration_service.py +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -18,23 +18,21 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - migratable_resource as gca_migratable_resource, -) +from google.cloud.aiplatform_v1beta1.types import migratable_resource as gca_migratable_resource from google.cloud.aiplatform_v1beta1.types import operation from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", - "BatchMigrateResourcesRequest", - "MigrateResourceRequest", - "BatchMigrateResourcesResponse", - "MigrateResourceResponse", - "BatchMigrateResourcesOperationMetadata", + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'BatchMigrateResourcesRequest', + 'MigrateResourceRequest', + 'BatchMigrateResourcesResponse', + 'MigrateResourceResponse', + 'BatchMigrateResourcesOperationMetadata', }, ) @@ -101,8 +99,8 @@ class SearchMigratableResourcesResponse(proto.Message): def raw_page(self): return self - migratable_resources = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource, + migratable_resources = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_migratable_resource.MigratableResource, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -125,8 +123,8 @@ class BatchMigrateResourcesRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - migrate_resource_requests = proto.RepeatedField( - proto.MESSAGE, number=2, message="MigrateResourceRequest", + migrate_resource_requests = proto.RepeatedField(proto.MESSAGE, number=2, + message='MigrateResourceRequest', ) @@ -150,7 +148,6 @@ class MigrateResourceRequest(proto.Message): datalabeling.googleapis.com to AI Platform's Dataset. """ - class MigrateMlEngineModelVersionConfig(proto.Message): r"""Config for migrating version in ml.googleapis.com to AI Platform's Model. @@ -238,7 +235,6 @@ class MigrateDataLabelingDatasetConfig(proto.Message): AnnotatedDatasets have to belong to the datalabeling Dataset. """ - class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): r"""Config for migrating AnnotatedDataset in datalabeling.googleapis.com to AI Platform's SavedQuery. @@ -256,31 +252,23 @@ class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=2) - migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( - proto.MESSAGE, - number=3, - message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField(proto.MESSAGE, number=3, + message='MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig', ) - migrate_ml_engine_model_version_config = proto.Field( - proto.MESSAGE, - number=1, - oneof="request", + migrate_ml_engine_model_version_config = proto.Field(proto.MESSAGE, number=1, oneof='request', message=MigrateMlEngineModelVersionConfig, ) - migrate_automl_model_config = proto.Field( - proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, + migrate_automl_model_config = proto.Field(proto.MESSAGE, number=2, oneof='request', + message=MigrateAutomlModelConfig, ) - migrate_automl_dataset_config = proto.Field( - proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, + migrate_automl_dataset_config = proto.Field(proto.MESSAGE, number=3, oneof='request', + message=MigrateAutomlDatasetConfig, ) - migrate_data_labeling_dataset_config = proto.Field( - proto.MESSAGE, - number=4, - oneof="request", + migrate_data_labeling_dataset_config = proto.Field(proto.MESSAGE, number=4, oneof='request', message=MigrateDataLabelingDatasetConfig, ) @@ -294,8 +282,8 @@ class BatchMigrateResourcesResponse(proto.Message): Successfully migrated resources. """ - migrate_resource_responses = proto.RepeatedField( - proto.MESSAGE, number=1, message="MigrateResourceResponse", + migrate_resource_responses = proto.RepeatedField(proto.MESSAGE, number=1, + message='MigrateResourceResponse', ) @@ -313,12 +301,12 @@ class MigrateResourceResponse(proto.Message): datalabeling.googleapis.com. """ - dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") + dataset = proto.Field(proto.STRING, number=1, oneof='migrated_resource') - model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") + model = proto.Field(proto.STRING, number=2, oneof='migrated_resource') - migratable_resource = proto.Field( - proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, + migratable_resource = proto.Field(proto.MESSAGE, number=3, + message=gca_migratable_resource.MigratableResource, ) @@ -333,7 +321,6 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): Partial results that reflect the latest migration operation progress. """ - class PartialResult(proto.Message): r"""Represents a partial result in batch migration operation for one ``MigrateResourceRequest``. @@ -351,24 +338,24 @@ class PartialResult(proto.Message): [MigrateResourceRequest.migrate_resource_requests][]. """ - error = proto.Field( - proto.MESSAGE, number=2, oneof="result", message=status.Status, + error = proto.Field(proto.MESSAGE, number=2, oneof='result', + message=status.Status, ) - model = proto.Field(proto.STRING, number=3, oneof="result") + model = proto.Field(proto.STRING, number=3, oneof='result') - dataset = proto.Field(proto.STRING, number=4, oneof="result") + dataset = proto.Field(proto.STRING, number=4, oneof='result') - request = proto.Field( - proto.MESSAGE, number=1, message="MigrateResourceRequest", + request = proto.Field(proto.MESSAGE, number=1, + message='MigrateResourceRequest', ) - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) - partial_results = proto.RepeatedField( - proto.MESSAGE, number=2, message=PartialResult, + partial_results = proto.RepeatedField(proto.MESSAGE, number=2, + message=PartialResult, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 4dcf6baefa..aaa87f85bb 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -27,8 +27,13 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Model', + 'PredictSchemata', + 'ModelContainerSpec', + 'Port', + }, ) @@ -249,7 +254,6 @@ class Model(proto.Message): Model. If set, this Model and all sub-resources of this Model will be secured by this key. """ - class DeploymentResourcesType(proto.Enum): r"""Identifies a type of Model's prediction resources.""" DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 @@ -286,7 +290,6 @@ class ExportFormat(proto.Message): Output only. The content of this Model that may be exported. """ - class ExportableContent(proto.Enum): r"""The Model content that can be exported.""" EXPORTABLE_CONTENT_UNSPECIFIED = 0 @@ -295,8 +298,8 @@ class ExportableContent(proto.Enum): id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField( - proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", + exportable_contents = proto.RepeatedField(proto.ENUM, number=2, + enum='Model.ExportFormat.ExportableContent', ) name = proto.Field(proto.STRING, number=1) @@ -305,48 +308,58 @@ class ExportableContent(proto.Enum): description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) + predict_schemata = proto.Field(proto.MESSAGE, number=4, + message='PredictSchemata', + ) metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) - supported_export_formats = proto.RepeatedField( - proto.MESSAGE, number=20, message=ExportFormat, + supported_export_formats = proto.RepeatedField(proto.MESSAGE, number=20, + message=ExportFormat, ) training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) + container_spec = proto.Field(proto.MESSAGE, number=9, + message='ModelContainerSpec', + ) artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField( - proto.ENUM, number=10, enum=DeploymentResourcesType, + supported_deployment_resources_types = proto.RepeatedField(proto.ENUM, number=10, + enum=DeploymentResourcesType, ) supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) - deployed_models = proto.RepeatedField( - proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, + deployed_models = proto.RepeatedField(proto.MESSAGE, number=15, + message=deployed_model_ref.DeployedModelRef, ) - explanation_spec = proto.Field( - proto.MESSAGE, number=23, message=explanation.ExplanationSpec, + explanation_spec = proto.Field(proto.MESSAGE, number=23, + message=explanation.ExplanationSpec, ) etag = proto.Field(proto.STRING, number=16) labels = proto.MapField(proto.STRING, proto.STRING, number=17) - encryption_spec = proto.Field( - proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=24, + message=gca_encryption_spec.EncryptionSpec, ) @@ -654,9 +667,13 @@ class ModelContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) + env = proto.RepeatedField(proto.MESSAGE, number=4, + message=env_var.EnvVar, + ) - ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) + ports = proto.RepeatedField(proto.MESSAGE, number=5, + message='Port', + ) predict_route = proto.Field(proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py new file mode 100644 index 0000000000..c500a28a8b --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import feature_monitoring_stats +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import model_monitoring +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ModelDeploymentMonitoringObjectiveType', + 'ModelDeploymentMonitoringJob', + 'ModelDeploymentMonitoringBigQueryTable', + 'ModelDeploymentMonitoringObjectiveConfig', + 'ModelDeploymentMonitoringScheduleConfig', + 'ModelMonitoringStatsAnomalies', + }, +) + + +class ModelDeploymentMonitoringObjectiveType(proto.Enum): + r"""The Model Monitoring Objective types.""" + MODEL_DEPLOYMENT_MONITORING_OBJECTIVE_TYPE_UNSPECIFIED = 0 + RAW_FEATURE_SKEW = 1 + RAW_FEATURE_DRIFT = 2 + FEATURE_ATTRIBUTION_SKEW = 3 + FEATURE_ATTRIBUTION_DRIFT = 4 + + +class ModelDeploymentMonitoringJob(proto.Message): + r"""Represents a job that runs periodically to monitor the + deployed models in an endpoint. It will analyze the logged + training & prediction data to detect any abnormal behaviors. + + Attributes: + name (str): + Output only. Resource name of a + ModelDeploymentMonitoringJob. + display_name (str): + Required. The user-defined name of the + ModelDeploymentMonitoringJob. The name can be up + to 128 characters long and can be consist of any + UTF-8 characters. + Display name of a ModelDeploymentMonitoringJob. + endpoint (str): + Required. Endpoint resource name. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + state (google.cloud.aiplatform_v1beta1.types.JobState): + Output only. The detailed state of the + monitoring job. When the job is still creating, + the state will be 'PENDING'. Once the job is + successfully created, the state will be + 'RUNNING'. Pause the job, the state will be + 'PAUSED'. + Resume the job, the state will return to + 'RUNNING'. + schedule_state (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob.MonitoringScheduleState): + Output only. Schedule state when the + monitoring job is in Running state. + model_deployment_monitoring_objective_configs (Sequence[google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringObjectiveConfig]): + Required. The config for monitoring + objectives. This is a per DeployedModel config. + Each DeployedModel needs to be configed + separately. + model_deployment_monitoring_schedule_config (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringScheduleConfig): + Required. Schedule config for running the + monitoring job. + logging_sampling_strategy (google.cloud.aiplatform_v1beta1.types.SamplingStrategy): + Required. Sample Strategy for logging. + model_monitoring_alert_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringAlertConfig): + Alert config for model monitoring. + predict_instance_schema_uri (str): + YAML schema file uri describing the format of + a single instance, which are given to format + this Endpoint's prediction (and explanation). If + not set, we will generate predict schema from + collected predict requests. + sample_predict_instance (google.protobuf.struct_pb2.Value): + Sample Predict instance, same format as + ``PredictRequest.instances``, + this can be set as a replacement of + ``ModelDeploymentMonitoringJob.predict_instance_schema_uri``. + If not set, we will generate predict schema from collected + predict requests. + analysis_instance_schema_uri (str): + YAML schema file uri describing the format of a single + instance that you want Tensorflow Data Validation (TFDV) to + analyze. + + If this field is empty, all the feature data types are + inferred from + ``predict_instance_schema_uri``, + meaning that TFDV will use the data in the exact format(data + type) as prediction request/response. If there are any data + type differences between predict instance and TFDV instance, + this field can be used to override the schema. For models + trained with AI Platform, this field must be set as all the + fields in predict instance formatted as string. + bigquery_tables (Sequence[google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringBigQueryTable]): + Output only. The created bigquery tables for + the job under customer project. Customer could + do their own query & analysis. There could be 4 + log tables in maximum: + 1. Training data logging predict + request/response 2. Serving data logging predict + request/response + log_ttl (google.protobuf.duration_pb2.Duration): + The TTL of BigQuery tables in user projects + which stores logs. A day is the basic unit of + the TTL and we take the ceil of TTL/86400(a + day). e.g. { second: 3600} indicates ttl = 1 + day. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob.LabelsEntry]): + The labels with user-defined metadata to + organize your ModelDeploymentMonitoringJob. + + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + ModelDeploymentMonitoringJob was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + ModelDeploymentMonitoringJob was updated most + recently. + next_schedule_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this monitoring + pipeline will be scheduled to run for the next + round. + stats_anomalies_base_directory (google.cloud.aiplatform_v1beta1.types.GcsDestination): + Stats anomalies base folder path. + """ + class MonitoringScheduleState(proto.Enum): + r"""The state to Specify the monitoring pipeline.""" + MONITORING_SCHEDULE_STATE_UNSPECIFIED = 0 + PENDING = 1 + OFFLINE = 2 + RUNNING = 3 + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + endpoint = proto.Field(proto.STRING, number=3) + + state = proto.Field(proto.ENUM, number=4, + enum=job_state.JobState, + ) + + schedule_state = proto.Field(proto.ENUM, number=5, + enum=MonitoringScheduleState, + ) + + model_deployment_monitoring_objective_configs = proto.RepeatedField(proto.MESSAGE, number=6, + message='ModelDeploymentMonitoringObjectiveConfig', + ) + + model_deployment_monitoring_schedule_config = proto.Field(proto.MESSAGE, number=7, + message='ModelDeploymentMonitoringScheduleConfig', + ) + + logging_sampling_strategy = proto.Field(proto.MESSAGE, number=8, + message=model_monitoring.SamplingStrategy, + ) + + model_monitoring_alert_config = proto.Field(proto.MESSAGE, number=15, + message=model_monitoring.ModelMonitoringAlertConfig, + ) + + predict_instance_schema_uri = proto.Field(proto.STRING, number=9) + + sample_predict_instance = proto.Field(proto.MESSAGE, number=19, + message=struct.Value, + ) + + analysis_instance_schema_uri = proto.Field(proto.STRING, number=16) + + bigquery_tables = proto.RepeatedField(proto.MESSAGE, number=10, + message='ModelDeploymentMonitoringBigQueryTable', + ) + + log_ttl = proto.Field(proto.MESSAGE, number=17, + message=duration.Duration, + ) + + labels = proto.MapField(proto.STRING, proto.STRING, number=11) + + create_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) + + next_schedule_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) + + stats_anomalies_base_directory = proto.Field(proto.MESSAGE, number=20, + message=io.GcsDestination, + ) + + +class ModelDeploymentMonitoringBigQueryTable(proto.Message): + r"""ModelDeploymentMonitoringBigQueryTable specifies the BigQuery + table name as well as some information of the logs stored in + this table. + + Attributes: + log_source (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringBigQueryTable.LogSource): + The source of log. + log_type (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringBigQueryTable.LogType): + The type of log. + bigquery_table_path (str): + The created BigQuery table to store logs. Customer could do + their own query & analysis. Format: + ``bq://.model_deployment_monitoring_._`` + """ + class LogSource(proto.Enum): + r"""Indicates where does the log come from.""" + LOG_SOURCE_UNSPECIFIED = 0 + TRAINING = 1 + SERVING = 2 + + class LogType(proto.Enum): + r"""Indicates what type of traffic does the log belong to.""" + LOG_TYPE_UNSPECIFIED = 0 + PREDICT = 1 + EXPLAIN = 2 + + log_source = proto.Field(proto.ENUM, number=1, + enum=LogSource, + ) + + log_type = proto.Field(proto.ENUM, number=2, + enum=LogType, + ) + + bigquery_table_path = proto.Field(proto.STRING, number=3) + + +class ModelDeploymentMonitoringObjectiveConfig(proto.Message): + r"""ModelDeploymentMonitoringObjectiveConfig contains the pair of + deployed_model_id to ModelMonitoringObjectiveConfig. + + Attributes: + deployed_model_id (str): + The DeployedModel ID of the objective config. + objective_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig): + The objective config of for the + modelmonitoring job of this deployed model. + """ + + deployed_model_id = proto.Field(proto.STRING, number=1) + + objective_config = proto.Field(proto.MESSAGE, number=2, + message=model_monitoring.ModelMonitoringObjectiveConfig, + ) + + +class ModelDeploymentMonitoringScheduleConfig(proto.Message): + r"""The config for scheduling monitoring job. + + Attributes: + monitor_interval (google.protobuf.duration_pb2.Duration): + Required. The model monitoring job running + interval. It will be rounded up to next full + hour. + """ + + monitor_interval = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) + + +class ModelMonitoringStatsAnomalies(proto.Message): + r"""Statistics and anomalies generated by Model Monitoring. + + Attributes: + objective (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringObjectiveType): + Model Monitoring Objective those stats and + anomalies belonging to. + deployed_model_id (str): + Deployed Model ID. + anomaly_count (int): + Number of anomalies within all stats. + feature_stats (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringStatsAnomalies.FeatureHistoricStatsAnomalies]): + A list of historical Stats and Anomalies + generated for all Features. + """ + class FeatureHistoricStatsAnomalies(proto.Message): + r"""Historical Stats (and Anomalies) for a specific Feature. + + Attributes: + feature_display_name (str): + Display Name of the Feature. + threshold (google.cloud.aiplatform_v1beta1.types.ThresholdConfig): + Threshold for anomaly detection. + training_stats (google.cloud.aiplatform_v1beta1.types.FeatureStatsAnomaly): + Stats calculated for the Training Dataset. + prediction_stats (Sequence[google.cloud.aiplatform_v1beta1.types.FeatureStatsAnomaly]): + A list of historical stats generated by + different time window's Prediction Dataset. + """ + + feature_display_name = proto.Field(proto.STRING, number=1) + + threshold = proto.Field(proto.MESSAGE, number=3, + message=model_monitoring.ThresholdConfig, + ) + + training_stats = proto.Field(proto.MESSAGE, number=4, + message=feature_monitoring_stats.FeatureStatsAnomaly, + ) + + prediction_stats = proto.RepeatedField(proto.MESSAGE, number=5, + message=feature_monitoring_stats.FeatureStatsAnomaly, + ) + + objective = proto.Field(proto.ENUM, number=1, + enum='ModelDeploymentMonitoringObjectiveType', + ) + + deployed_model_id = proto.Field(proto.STRING, number=2) + + anomaly_count = proto.Field(proto.INT32, number=3) + + feature_stats = proto.RepeatedField(proto.MESSAGE, number=4, + message=FeatureHistoricStatsAnomalies, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index 391bc38cf4..d0a4a5a146 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -24,7 +24,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluation",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ModelEvaluation', + }, ) @@ -71,7 +74,6 @@ class ModelEvaluation(proto.Message): that are used for explaining the predicted values on the evaluated data. """ - class ModelEvaluationExplanationSpec(proto.Message): r""" @@ -89,26 +91,30 @@ class ModelEvaluationExplanationSpec(proto.Message): explanation_type = proto.Field(proto.STRING, number=1) - explanation_spec = proto.Field( - proto.MESSAGE, number=2, message=explanation.ExplanationSpec, + explanation_spec = proto.Field(proto.MESSAGE, number=2, + message=explanation.ExplanationSpec, ) name = proto.Field(proto.STRING, number=1) metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + metrics = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) slice_dimensions = proto.RepeatedField(proto.STRING, number=5) - model_explanation = proto.Field( - proto.MESSAGE, number=8, message=explanation.ModelExplanation, + model_explanation = proto.Field(proto.MESSAGE, number=8, + message=explanation.ModelExplanation, ) - explanation_specs = proto.RepeatedField( - proto.MESSAGE, number=9, message=ModelEvaluationExplanationSpec, + explanation_specs = proto.RepeatedField(proto.MESSAGE, number=9, + message=ModelEvaluationExplanationSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index 2d66e29a9f..3895dd1170 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -23,7 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluationSlice",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ModelEvaluationSlice', + }, ) @@ -54,7 +57,6 @@ class ModelEvaluationSlice(proto.Message): Output only. Timestamp when this ModelEvaluationSlice was created. """ - class Slice(proto.Message): r"""Definition of a slice. @@ -79,13 +81,19 @@ class Slice(proto.Message): name = proto.Field(proto.STRING, number=1) - slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) + slice_ = proto.Field(proto.MESSAGE, number=2, + message=Slice, + ) metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + metrics = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) - create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py new file mode 100644 index 0000000000..f57417be64 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import io + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ModelMonitoringObjectiveConfig', + 'ModelMonitoringAlertConfig', + 'ThresholdConfig', + 'SamplingStrategy', + }, +) + + +class ModelMonitoringObjectiveConfig(proto.Message): + r"""Next ID: 6 + + Attributes: + training_dataset (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.TrainingDataset): + Training dataset for models. This field has + to be set only if + TrainingPredictionSkewDetectionConfig is + specified. + training_prediction_skew_detection_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig): + The config for skew between training data and + prediction data. + prediction_drift_detection_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig): + The config for drift of prediction data. + """ + class TrainingDataset(proto.Message): + r"""Training Dataset information. + + Attributes: + dataset (str): + The resource name of the Dataset used to + train this Model. + gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): + The Google Cloud Storage uri of the unmanaged + Dataset used to train this Model. + bigquery_source (google.cloud.aiplatform_v1beta1.types.BigQuerySource): + The BigQuery table of the unmanaged Dataset + used to train this Model. + data_format (str): + Data format of the dataset, only applicable + if the input is from Google Cloud Storage. + The possible formats are: + + "tf-record" + The source file is a TFRecord file. + + "csv" + The source file is a CSV file. + target_field (str): + The target field name the model is to + predict. This field will be excluded when doing + Predict and (or) Explain for the training data. + logging_sampling_strategy (google.cloud.aiplatform_v1beta1.types.SamplingStrategy): + Strategy to sample data from Training + Dataset. If not set, we process the whole + dataset. + """ + + dataset = proto.Field(proto.STRING, number=3, oneof='data_source') + + gcs_source = proto.Field(proto.MESSAGE, number=4, oneof='data_source', + message=io.GcsSource, + ) + + bigquery_source = proto.Field(proto.MESSAGE, number=5, oneof='data_source', + message=io.BigQuerySource, + ) + + data_format = proto.Field(proto.STRING, number=2) + + target_field = proto.Field(proto.STRING, number=6) + + logging_sampling_strategy = proto.Field(proto.MESSAGE, number=7, + message='SamplingStrategy', + ) + + class TrainingPredictionSkewDetectionConfig(proto.Message): + r"""The config for Training & Prediction data skew detection. It + specifies the training dataset sources and the skew detection + parameters. + + Attributes: + skew_thresholds (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig.SkewThresholdsEntry]): + Key is the feature name and value is the + threshold. If a feature needs to be monitored + for skew, a value threshold must be configed for + that feature. The threshold here is against + feature distribution distance between the + training and prediction feature. + """ + + skew_thresholds = proto.MapField(proto.STRING, proto.MESSAGE, number=1, + message='ThresholdConfig', + ) + + class PredictionDriftDetectionConfig(proto.Message): + r"""The config for Prediction data drift detection. + + Attributes: + drift_thresholds (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig.DriftThresholdsEntry]): + Key is the feature name and value is the + threshold. If a feature needs to be monitored + for drift, a value threshold must be configed + for that feature. The threshold here is against + feature distribution distance between different + time windws. + """ + + drift_thresholds = proto.MapField(proto.STRING, proto.MESSAGE, number=1, + message='ThresholdConfig', + ) + + training_dataset = proto.Field(proto.MESSAGE, number=1, + message=TrainingDataset, + ) + + training_prediction_skew_detection_config = proto.Field(proto.MESSAGE, number=2, + message=TrainingPredictionSkewDetectionConfig, + ) + + prediction_drift_detection_config = proto.Field(proto.MESSAGE, number=3, + message=PredictionDriftDetectionConfig, + ) + + +class ModelMonitoringAlertConfig(proto.Message): + r"""Next ID: 2 + + Attributes: + email_alert_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringAlertConfig.EmailAlertConfig): + Email alert config. + """ + class EmailAlertConfig(proto.Message): + r"""The config for email alert. + + Attributes: + user_emails (Sequence[str]): + The email addresses to send the alert. + """ + + user_emails = proto.RepeatedField(proto.STRING, number=1) + + email_alert_config = proto.Field(proto.MESSAGE, number=1, oneof='alert', + message=EmailAlertConfig, + ) + + +class ThresholdConfig(proto.Message): + r"""The config for feature monitoring threshold. + Next ID: 3 + + Attributes: + value (float): + Specify a threshold value that can trigger + the alert. If this threshold config is for + feature distribution distance: 1. For + categorical feature, the distribution distance + is calculated by L-inifinity norm. + 2. For numerical feature, the distribution + distance is calculated by Jensen–Shannon + divergence. + Each feature must have a non-zero threshold if + they need to be monitored. Otherwise no alert + will be triggered for that feature. + """ + + value = proto.Field(proto.DOUBLE, number=1, oneof='threshold') + + +class SamplingStrategy(proto.Message): + r"""Sampling Strategy for logging, can be for both training and + prediction dataset. + Next ID: 2 + + Attributes: + random_sample_config (google.cloud.aiplatform_v1beta1.types.SamplingStrategy.RandomSampleConfig): + Random sample config. Will support more + sampling strategies later. + """ + class RandomSampleConfig(proto.Message): + r"""Requests are randomly selected. + + Attributes: + sample_rate (float): + Sample rate (0, 1] + """ + + sample_rate = proto.Field(proto.DOUBLE, number=1) + + random_sample_config = proto.Field(proto.MESSAGE, number=1, + message=RandomSampleConfig, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index e0d8e148ab..46b5328166 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -27,25 +27,25 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "UploadModelRequest", - "UploadModelOperationMetadata", - "UploadModelResponse", - "GetModelRequest", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", - "DeleteModelRequest", - "ExportModelRequest", - "ExportModelOperationMetadata", - "ExportModelResponse", - "GetModelEvaluationRequest", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "GetModelEvaluationSliceRequest", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", + 'UploadModelRequest', + 'UploadModelOperationMetadata', + 'UploadModelResponse', + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'UpdateModelRequest', + 'DeleteModelRequest', + 'ExportModelRequest', + 'ExportModelOperationMetadata', + 'ExportModelResponse', + 'GetModelEvaluationRequest', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'GetModelEvaluationSliceRequest', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', }, ) @@ -65,7 +65,9 @@ class UploadModelRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) + model = proto.Field(proto.MESSAGE, number=2, + message=gca_model.Model, + ) class UploadModelOperationMetadata(proto.Message): @@ -78,8 +80,8 @@ class UploadModelOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -159,7 +161,9 @@ class ListModelsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelsResponse(proto.Message): @@ -179,7 +183,9 @@ class ListModelsResponse(proto.Message): def raw_page(self): return self - models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) + models = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_model.Model, + ) next_page_token = proto.Field(proto.STRING, number=2) @@ -198,9 +204,13 @@ class UpdateModelRequest(proto.Message): `FieldMask `__. """ - model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) + model = proto.Field(proto.MESSAGE, number=1, + message=gca_model.Model, + ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class DeleteModelRequest(proto.Message): @@ -229,7 +239,6 @@ class ExportModelRequest(proto.Message): Required. The desired output location and configuration. """ - class OutputConfig(proto.Message): r"""Output configuration for the Model export. @@ -261,17 +270,19 @@ class OutputConfig(proto.Message): export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field( - proto.MESSAGE, number=3, message=io.GcsDestination, + artifact_destination = proto.Field(proto.MESSAGE, number=3, + message=io.GcsDestination, ) - image_destination = proto.Field( - proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, + image_destination = proto.Field(proto.MESSAGE, number=4, + message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) + output_config = proto.Field(proto.MESSAGE, number=2, + message=OutputConfig, + ) class ExportModelOperationMetadata(proto.Message): @@ -286,7 +297,6 @@ class ExportModelOperationMetadata(proto.Message): Output only. Information further describing the output of this Model export. """ - class OutputInfo(proto.Message): r"""Further describes the output of the ExportModel. Supplements ``ExportModelRequest.OutputConfig``. @@ -308,11 +318,13 @@ class OutputInfo(proto.Message): image_output_uri = proto.Field(proto.STRING, number=3) - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) - output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) + output_info = proto.Field(proto.MESSAGE, number=2, + message=OutputInfo, + ) class ExportModelResponse(proto.Message): @@ -366,7 +378,9 @@ class ListModelEvaluationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelEvaluationsResponse(proto.Message): @@ -387,8 +401,8 @@ class ListModelEvaluationsResponse(proto.Message): def raw_page(self): return self - model_evaluations = proto.RepeatedField( - proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, + model_evaluations = proto.RepeatedField(proto.MESSAGE, number=1, + message=model_evaluation.ModelEvaluation, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -441,7 +455,9 @@ class ListModelEvaluationSlicesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelEvaluationSlicesResponse(proto.Message): @@ -462,8 +478,8 @@ class ListModelEvaluationSlicesResponse(proto.Message): def raw_page(self): return self - model_evaluation_slices = proto.RepeatedField( - proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, + model_evaluation_slices = proto.RepeatedField(proto.MESSAGE, number=1, + message=model_evaluation_slice.ModelEvaluationSlice, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 90565867e8..887e903ff2 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -23,8 +23,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'GenericOperationMetadata', + 'DeleteOperationMetadata', + }, ) @@ -48,13 +51,17 @@ class GenericOperationMetadata(proto.Message): finish time. """ - partial_failures = proto.RepeatedField( - proto.MESSAGE, number=1, message=status.Status, + partial_failures = proto.RepeatedField(proto.MESSAGE, number=1, + message=status.Status, ) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=2, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) class DeleteOperationMetadata(proto.Message): @@ -65,8 +72,8 @@ class DeleteOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message="GenericOperationMetadata", + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message='GenericOperationMetadata', ) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index b06361dfa9..a5add3f9ca 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -18,21 +18,19 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", + 'CreateTrainingPipelineRequest', + 'GetTrainingPipelineRequest', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'DeleteTrainingPipelineRequest', + 'CancelTrainingPipelineRequest', }, ) @@ -52,8 +50,8 @@ class CreateTrainingPipelineRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field( - proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, + training_pipeline = proto.Field(proto.MESSAGE, number=2, + message=gca_training_pipeline.TrainingPipeline, ) @@ -115,7 +113,9 @@ class ListTrainingPipelinesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListTrainingPipelinesResponse(proto.Message): @@ -136,8 +136,8 @@ class ListTrainingPipelinesResponse(proto.Message): def raw_page(self): return self - training_pipelines = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, + training_pipelines = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_training_pipeline.TrainingPipeline, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py index cede653bd6..b04954f602 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"PipelineState",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'PipelineState', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index f7abe9e3e2..24011ca24d 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -23,12 +23,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "PredictRequest", - "PredictResponse", - "ExplainRequest", - "ExplainResponse", + 'PredictRequest', + 'PredictResponse', + 'ExplainRequest', + 'ExplainResponse', }, ) @@ -65,9 +65,13 @@ class PredictRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) + instances = proto.RepeatedField(proto.MESSAGE, number=2, + message=struct.Value, + ) - parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + parameters = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) class PredictResponse(proto.Message): @@ -87,7 +91,9 @@ class PredictResponse(proto.Message): served this prediction. """ - predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) + predictions = proto.RepeatedField(proto.MESSAGE, number=1, + message=struct.Value, + ) deployed_model_id = proto.Field(proto.STRING, number=2) @@ -139,12 +145,16 @@ class ExplainRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) + instances = proto.RepeatedField(proto.MESSAGE, number=2, + message=struct.Value, + ) - parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + parameters = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) - explanation_spec_override = proto.Field( - proto.MESSAGE, number=5, message=explanation.ExplanationSpecOverride, + explanation_spec_override = proto.Field(proto.MESSAGE, number=5, + message=explanation.ExplanationSpecOverride, ) deployed_model_id = proto.Field(proto.STRING, number=3) @@ -171,13 +181,15 @@ class ExplainResponse(proto.Message): ``PredictResponse.predictions``. """ - explanations = proto.RepeatedField( - proto.MESSAGE, number=1, message=explanation.Explanation, + explanations = proto.RepeatedField(proto.MESSAGE, number=1, + message=explanation.Explanation, ) deployed_model_id = proto.Field(proto.STRING, number=2) - predictions = proto.RepeatedField(proto.MESSAGE, number=3, message=struct.Value,) + predictions = proto.RepeatedField(proto.MESSAGE, number=3, + message=struct.Value, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py index 4ac8c6a709..f75416157b 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"SpecialistPool",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'SpecialistPool', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 3ed6593bd6..aa9e9235ef 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -24,16 +24,16 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateSpecialistPoolRequest", - "CreateSpecialistPoolOperationMetadata", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", - "UpdateSpecialistPoolOperationMetadata", + 'CreateSpecialistPoolRequest', + 'CreateSpecialistPoolOperationMetadata', + 'GetSpecialistPoolRequest', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'DeleteSpecialistPoolRequest', + 'UpdateSpecialistPoolRequest', + 'UpdateSpecialistPoolOperationMetadata', }, ) @@ -53,8 +53,8 @@ class CreateSpecialistPoolRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field( - proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field(proto.MESSAGE, number=2, + message=gca_specialist_pool.SpecialistPool, ) @@ -67,8 +67,8 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -113,7 +113,9 @@ class ListSpecialistPoolsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) + read_mask = proto.Field(proto.MESSAGE, number=4, + message=field_mask.FieldMask, + ) class ListSpecialistPoolsResponse(proto.Message): @@ -132,8 +134,8 @@ class ListSpecialistPoolsResponse(proto.Message): def raw_page(self): return self - specialist_pools = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + specialist_pools = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_specialist_pool.SpecialistPool, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -173,11 +175,13 @@ class UpdateSpecialistPoolRequest(proto.Message): resource. """ - specialist_pool = proto.Field( - proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field(proto.MESSAGE, number=1, + message=gca_specialist_pool.SpecialistPool, ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class UpdateSpecialistPoolOperationMetadata(proto.Message): @@ -195,8 +199,8 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field( - proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=2, + message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 092d3a3e2d..f1f28d8669 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -18,13 +18,19 @@ import proto # type: ignore +from google.protobuf import duration_pb2 as duration # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"Study", "Trial", "StudySpec", "Measurement",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Study', + 'Trial', + 'StudySpec', + 'Measurement', + }, ) @@ -51,7 +57,6 @@ class Study(proto.Message): Study is inactive. This should be empty if a study is ACTIVE or COMPLETED. """ - class State(proto.Enum): r"""Describes the Study state.""" STATE_UNSPECIFIED = 0 @@ -63,11 +68,17 @@ class State(proto.Enum): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=3, message="StudySpec",) + study_spec = proto.Field(proto.MESSAGE, number=3, + message='StudySpec', + ) - state = proto.Field(proto.ENUM, number=4, enum=State,) + state = proto.Field(proto.ENUM, number=4, + enum=State, + ) - create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) inactive_reason = proto.Field(proto.STRING, number=6) @@ -91,17 +102,35 @@ class Trial(proto.Message): final_measurement (google.cloud.aiplatform_v1beta1.types.Measurement): Output only. The final measurement containing the objective value. + measurements (Sequence[google.cloud.aiplatform_v1beta1.types.Measurement]): + Output only. A list of measurements that are strictly + lexicographically ordered by their induced tuples (steps, + elapsed_duration). These are used for early stopping + computations. start_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the Trial was started. end_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the Trial's status changed to ``SUCCEEDED`` or ``INFEASIBLE``. + client_id (str): + Output only. The identifier of the client that originally + requested this Trial. Each client is identified by a unique + client_id. When a client asks for a suggestion, Vizier will + assign it a Trial. The client should evaluate the Trial, + complete it, and report back to Vizier. If suggestion is + asked again by same client_id before the Trial is completed, + the same Trial will be returned. Multiple clients with + different client_ids can ask for suggestions simultaneously, + each of them will get their own Trial. + infeasible_reason (str): + Output only. A human readable string describing why the + Trial is infeasible. This is set only if Trial state is + ``INFEASIBLE``. custom_job (str): Output only. The CustomJob name linked to the Trial. It's set for a HyperparameterTuningJob's Trial. """ - class State(proto.Enum): r"""Describes a Trial state.""" STATE_UNSPECIFIED = 0 @@ -129,21 +158,41 @@ class Parameter(proto.Message): parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) + value = proto.Field(proto.MESSAGE, number=2, + message=struct.Value, + ) name = proto.Field(proto.STRING, number=1) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, enum=State,) + state = proto.Field(proto.ENUM, number=3, + enum=State, + ) + + parameters = proto.RepeatedField(proto.MESSAGE, number=4, + message=Parameter, + ) + + final_measurement = proto.Field(proto.MESSAGE, number=5, + message='Measurement', + ) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) + measurements = proto.RepeatedField(proto.MESSAGE, number=6, + message='Measurement', + ) + + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) - final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + client_id = proto.Field(proto.STRING, number=9) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + infeasible_reason = proto.Field(proto.STRING, number=10) custom_job = proto.Field(proto.STRING, number=11) @@ -176,7 +225,6 @@ class StudySpec(proto.Message): Describe which measurement selection type will be used """ - class Algorithm(proto.Enum): r"""The available search algorithms for the Study.""" ALGORITHM_UNSPECIFIED = 0 @@ -222,7 +270,6 @@ class MetricSpec(proto.Message): Required. The optimization goal of the metric. """ - class GoalType(proto.Enum): r"""The available types of optimization goals.""" GOAL_TYPE_UNSPECIFIED = 0 @@ -231,7 +278,9 @@ class GoalType(proto.Enum): metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) + goal = proto.Field(proto.ENUM, number=2, + enum='StudySpec.MetricSpec.GoalType', + ) class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. @@ -259,7 +308,6 @@ class ParameterSpec(proto.Message): If two items in conditional_parameter_specs have the same name, they must have disjoint parent_value_condition. """ - class ScaleType(proto.Enum): r"""The type of scaling that should be applied to this parameter.""" SCALE_TYPE_UNSPECIFIED = 0 @@ -342,7 +390,6 @@ class ConditionalParameterSpec(proto.Message): Required. The spec for a conditional parameter. """ - class DiscreteValueCondition(proto.Message): r"""Represents the spec to match discrete values from parent parameter. @@ -384,69 +431,46 @@ class CategoricalValueCondition(proto.Message): values = proto.RepeatedField(proto.STRING, number=1) - parent_discrete_values = proto.Field( - proto.MESSAGE, - number=2, - oneof="parent_value_condition", - message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", + parent_discrete_values = proto.Field(proto.MESSAGE, number=2, oneof='parent_value_condition', + message='StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition', ) - parent_int_values = proto.Field( - proto.MESSAGE, - number=3, - oneof="parent_value_condition", - message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", + parent_int_values = proto.Field(proto.MESSAGE, number=3, oneof='parent_value_condition', + message='StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition', ) - parent_categorical_values = proto.Field( - proto.MESSAGE, - number=4, - oneof="parent_value_condition", - message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", + parent_categorical_values = proto.Field(proto.MESSAGE, number=4, oneof='parent_value_condition', + message='StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition', ) - parameter_spec = proto.Field( - proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", + parameter_spec = proto.Field(proto.MESSAGE, number=1, + message='StudySpec.ParameterSpec', ) - double_value_spec = proto.Field( - proto.MESSAGE, - number=2, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.DoubleValueSpec", + double_value_spec = proto.Field(proto.MESSAGE, number=2, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.DoubleValueSpec', ) - integer_value_spec = proto.Field( - proto.MESSAGE, - number=3, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.IntegerValueSpec", + integer_value_spec = proto.Field(proto.MESSAGE, number=3, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.IntegerValueSpec', ) - categorical_value_spec = proto.Field( - proto.MESSAGE, - number=4, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.CategoricalValueSpec", + categorical_value_spec = proto.Field(proto.MESSAGE, number=4, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.CategoricalValueSpec', ) - discrete_value_spec = proto.Field( - proto.MESSAGE, - number=5, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.DiscreteValueSpec", + discrete_value_spec = proto.Field(proto.MESSAGE, number=5, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.DiscreteValueSpec', ) parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field( - proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", + scale_type = proto.Field(proto.ENUM, number=6, + enum='StudySpec.ParameterSpec.ScaleType', ) - conditional_parameter_specs = proto.RepeatedField( - proto.MESSAGE, - number=10, - message="StudySpec.ParameterSpec.ConditionalParameterSpec", + conditional_parameter_specs = proto.RepeatedField(proto.MESSAGE, number=10, + message='StudySpec.ParameterSpec.ConditionalParameterSpec', ) class DecayCurveAutomatedStoppingSpec(proto.Message): @@ -535,37 +559,36 @@ class ConvexStopConfig(proto.Message): use_seconds = proto.Field(proto.BOOL, number=5) - decay_curve_stopping_spec = proto.Field( - proto.MESSAGE, - number=4, - oneof="automated_stopping_spec", + decay_curve_stopping_spec = proto.Field(proto.MESSAGE, number=4, oneof='automated_stopping_spec', message=DecayCurveAutomatedStoppingSpec, ) - median_automated_stopping_spec = proto.Field( - proto.MESSAGE, - number=5, - oneof="automated_stopping_spec", + median_automated_stopping_spec = proto.Field(proto.MESSAGE, number=5, oneof='automated_stopping_spec', message=MedianAutomatedStoppingSpec, ) - convex_stop_config = proto.Field( - proto.MESSAGE, - number=8, - oneof="automated_stopping_spec", + convex_stop_config = proto.Field(proto.MESSAGE, number=8, oneof='automated_stopping_spec', message=ConvexStopConfig, ) - metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, + message=MetricSpec, + ) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) + parameters = proto.RepeatedField(proto.MESSAGE, number=2, + message=ParameterSpec, + ) - algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) + algorithm = proto.Field(proto.ENUM, number=3, + enum=Algorithm, + ) - observation_noise = proto.Field(proto.ENUM, number=6, enum=ObservationNoise,) + observation_noise = proto.Field(proto.ENUM, number=6, + enum=ObservationNoise, + ) - measurement_selection_type = proto.Field( - proto.ENUM, number=7, enum=MeasurementSelectionType, + measurement_selection_type = proto.Field(proto.ENUM, number=7, + enum=MeasurementSelectionType, ) @@ -575,6 +598,9 @@ class Measurement(proto.Message): suggested hyperparameter values. Attributes: + elapsed_duration (google.protobuf.duration_pb2.Duration): + Output only. Time that the Trial has been + running at the point of this Measurement. step_count (int): Output only. The number of steps the machine learning model has been trained for. Must be @@ -584,7 +610,6 @@ class Measurement(proto.Message): evaluating the objective functions using suggested Parameter values. """ - class Metric(proto.Message): r"""A message representing a metric in the measurement. @@ -601,9 +626,15 @@ class Metric(proto.Message): value = proto.Field(proto.DOUBLE, number=2) + elapsed_duration = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) + step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) + metrics = proto.RepeatedField(proto.MESSAGE, number=3, + message=Metric, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 3c03b0f47d..84f1a7d2c6 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "TrainingPipeline", - "InputDataConfig", - "FractionSplit", - "FilterSplit", - "PredefinedSplit", - "TimestampSplit", + 'TrainingPipeline', + 'InputDataConfig', + 'FractionSplit', + 'FilterSplit', + 'PredefinedSplit', + 'TimestampSplit', }, ) @@ -155,32 +155,52 @@ class TrainingPipeline(proto.Message): display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) + input_data_config = proto.Field(proto.MESSAGE, number=3, + message='InputDataConfig', + ) training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + training_task_inputs = proto.Field(proto.MESSAGE, number=5, + message=struct.Value, + ) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + training_task_metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) - model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) + model_to_upload = proto.Field(proto.MESSAGE, number=7, + message=model.Model, + ) - state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) + state = proto.Field(proto.ENUM, number=9, + enum=pipeline_state.PipelineState, + ) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + error = proto.Field(proto.MESSAGE, number=10, + message=status.Status, + ) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) labels = proto.MapField(proto.STRING, proto.STRING, number=15) - encryption_spec = proto.Field( - proto.MESSAGE, number=18, message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field(proto.MESSAGE, number=18, + message=gca_encryption_spec.EncryptionSpec, ) @@ -301,28 +321,28 @@ class InputDataConfig(proto.Message): ``annotation_schema_uri``. """ - fraction_split = proto.Field( - proto.MESSAGE, number=2, oneof="split", message="FractionSplit", + fraction_split = proto.Field(proto.MESSAGE, number=2, oneof='split', + message='FractionSplit', ) - filter_split = proto.Field( - proto.MESSAGE, number=3, oneof="split", message="FilterSplit", + filter_split = proto.Field(proto.MESSAGE, number=3, oneof='split', + message='FilterSplit', ) - predefined_split = proto.Field( - proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", + predefined_split = proto.Field(proto.MESSAGE, number=4, oneof='split', + message='PredefinedSplit', ) - timestamp_split = proto.Field( - proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", + timestamp_split = proto.Field(proto.MESSAGE, number=5, oneof='split', + message='TimestampSplit', ) - gcs_destination = proto.Field( - proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=8, oneof='destination', + message=io.GcsDestination, ) - bigquery_destination = proto.Field( - proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, + bigquery_destination = proto.Field(proto.MESSAGE, number=10, oneof='destination', + message=io.BigQueryDestination, ) dataset_id = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index 25180ae567..a2ff3629c0 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"UserActionReference",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'UserActionReference', + }, ) @@ -39,13 +42,14 @@ class UserActionReference(proto.Message): LabelingJob. Format: 'projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}' method (str): - The method name of the API call. For example, - "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". + The method name of the API RPC call. For + example, + "/google.cloud.aiplatform.master.DatasetService.CreateDataset". """ - operation = proto.Field(proto.STRING, number=1, oneof="reference") + operation = proto.Field(proto.STRING, number=1, oneof='reference') - data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") + data_labeling_job = proto.Field(proto.STRING, number=2, oneof='reference') method = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/vizier_service.py b/google/cloud/aiplatform_v1beta1/types/vizier_service.py index 2b837c476e..9a7b4be68f 100644 --- a/google/cloud/aiplatform_v1beta1/types/vizier_service.py +++ b/google/cloud/aiplatform_v1beta1/types/vizier_service.py @@ -24,30 +24,30 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "GetStudyRequest", - "CreateStudyRequest", - "ListStudiesRequest", - "ListStudiesResponse", - "DeleteStudyRequest", - "LookupStudyRequest", - "SuggestTrialsRequest", - "SuggestTrialsResponse", - "SuggestTrialsMetadata", - "CreateTrialRequest", - "GetTrialRequest", - "ListTrialsRequest", - "ListTrialsResponse", - "AddTrialMeasurementRequest", - "CompleteTrialRequest", - "DeleteTrialRequest", - "CheckTrialEarlyStoppingStateRequest", - "CheckTrialEarlyStoppingStateResponse", - "CheckTrialEarlyStoppingStateMetatdata", - "StopTrialRequest", - "ListOptimalTrialsRequest", - "ListOptimalTrialsResponse", + 'GetStudyRequest', + 'CreateStudyRequest', + 'ListStudiesRequest', + 'ListStudiesResponse', + 'DeleteStudyRequest', + 'LookupStudyRequest', + 'SuggestTrialsRequest', + 'SuggestTrialsResponse', + 'SuggestTrialsMetadata', + 'CreateTrialRequest', + 'GetTrialRequest', + 'ListTrialsRequest', + 'ListTrialsResponse', + 'AddTrialMeasurementRequest', + 'CompleteTrialRequest', + 'DeleteTrialRequest', + 'CheckTrialEarlyStoppingStateRequest', + 'CheckTrialEarlyStoppingStateResponse', + 'CheckTrialEarlyStoppingStateMetatdata', + 'StopTrialRequest', + 'ListOptimalTrialsRequest', + 'ListOptimalTrialsResponse', }, ) @@ -81,7 +81,9 @@ class CreateStudyRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - study = proto.Field(proto.MESSAGE, number=2, message=gca_study.Study,) + study = proto.Field(proto.MESSAGE, number=2, + message=gca_study.Study, + ) class ListStudiesRequest(proto.Message): @@ -127,7 +129,9 @@ class ListStudiesResponse(proto.Message): def raw_page(self): return self - studies = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Study,) + studies = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_study.Study, + ) next_page_token = proto.Field(proto.STRING, number=2) @@ -209,13 +213,21 @@ class SuggestTrialsResponse(proto.Message): completed. """ - trials = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Trial,) + trials = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_study.Trial, + ) - study_state = proto.Field(proto.ENUM, number=2, enum=gca_study.Study.State,) + study_state = proto.Field(proto.ENUM, number=2, + enum=gca_study.Study.State, + ) - start_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) - end_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) class SuggestTrialsMetadata(proto.Message): @@ -234,8 +246,8 @@ class SuggestTrialsMetadata(proto.Message): Trial if the last suggested Trial was completed. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) client_id = proto.Field(proto.STRING, number=2) @@ -256,7 +268,9 @@ class CreateTrialRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - trial = proto.Field(proto.MESSAGE, number=2, message=gca_study.Trial,) + trial = proto.Field(proto.MESSAGE, number=2, + message=gca_study.Trial, + ) class GetTrialRequest(proto.Message): @@ -315,7 +329,9 @@ class ListTrialsResponse(proto.Message): def raw_page(self): return self - trials = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Trial,) + trials = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_study.Trial, + ) next_page_token = proto.Field(proto.STRING, number=2) @@ -335,7 +351,9 @@ class AddTrialMeasurementRequest(proto.Message): trial_name = proto.Field(proto.STRING, number=1) - measurement = proto.Field(proto.MESSAGE, number=3, message=gca_study.Measurement,) + measurement = proto.Field(proto.MESSAGE, number=3, + message=gca_study.Measurement, + ) class CompleteTrialRequest(proto.Message): @@ -362,8 +380,8 @@ class CompleteTrialRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - final_measurement = proto.Field( - proto.MESSAGE, number=2, message=gca_study.Measurement, + final_measurement = proto.Field(proto.MESSAGE, number=2, + message=gca_study.Measurement, ) trial_infeasible = proto.Field(proto.BOOL, number=3) @@ -424,8 +442,8 @@ class CheckTrialEarlyStoppingStateMetatdata(proto.Message): The Trial name. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) study = proto.Field(proto.STRING, number=2) @@ -471,8 +489,8 @@ class ListOptimalTrialsResponse(proto.Message): https://en.wikipedia.org/wiki/Pareto_efficiency """ - optimal_trials = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_study.Trial, + optimal_trials = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_study.Trial, ) diff --git a/noxfile.py b/noxfile.py index 35270f664f..32bd822f2b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -18,6 +18,7 @@ from __future__ import absolute_import import os +import pathlib import shutil import nox @@ -26,9 +27,11 @@ BLACK_VERSION = "black==19.10b0" BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.8" -SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] -UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] +DEFAULT_PYTHON_VERSION="3.8" +SYSTEM_TEST_PYTHON_VERSIONS=["3.8"] +UNIT_TEST_PYTHON_VERSIONS=["3.6","3.7","3.8","3.9"] + +CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() # 'docfx' is excluded since it only needs to run in 'docs-presubmit' nox.options.sessions = [ @@ -54,7 +57,9 @@ def lint(session): """ session.install("flake8", BLACK_VERSION) session.run( - "black", "--check", *BLACK_PATHS, + "black", + "--check", + *BLACK_PATHS, ) session.run("flake8", "google", "tests") @@ -71,7 +76,8 @@ def blacken(session): """ session.install(BLACK_VERSION) session.run( - "black", *BLACK_PATHS, + "black", + *BLACK_PATHS, ) @@ -84,13 +90,17 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. - session.install("asyncmock", "pytest-asyncio") - session.install( - "mock", "pytest", "pytest-cov", + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" ) - - session.install("-e", ".") + session.install("asyncmock", "pytest-asyncio", "-c", constraints_path) + + session.install("mock", "pytest", "pytest-cov", "-c", constraints_path) + + + session.install("-e", ".", "-c", constraints_path) + # Run py.test against the unit tests. session.run( @@ -107,7 +117,6 @@ def default(session): *session.posargs, ) - @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" @@ -117,11 +126,14 @@ def unit(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) def system(session): """Run the system test suite.""" + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) system_test_path = os.path.join("tests", "system.py") system_test_folder_path = os.path.join("tests", "system") # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. - if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": + if os.environ.get("RUN_SYSTEM_TESTS", "true") == 'false': session.skip("RUN_SYSTEM_TESTS is set to false, skipping") # Sanity check: Only run tests if the environment variable is set. if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): @@ -141,10 +153,9 @@ def system(session): # Install all test dependencies, then install this package into the # virtualenv's dist-packages. - session.install( - "mock", "pytest", "google-cloud-testutils", - ) - session.install("-e", ".") + session.install("mock", "pytest", "google-cloud-testutils", "-c", constraints_path) + session.install("-e", ".", "-c", constraints_path) + # Run py.test against the system tests. if system_test_exists: @@ -153,7 +164,7 @@ def system(session): "--quiet", f"--junitxml=system_{session.python}_sponge_log.xml", system_test_path, - *session.posargs, + *session.posargs ) if system_test_folder_exists: session.run( @@ -161,10 +172,11 @@ def system(session): "--quiet", f"--junitxml=system_{session.python}_sponge_log.xml", system_test_folder_path, - *session.posargs, + *session.posargs ) + @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): """Run the final coverage report. @@ -177,25 +189,23 @@ def cover(session): session.run("coverage", "erase") - @nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" - session.install("-e", ".") - session.install("sphinx", "alabaster", "recommonmark") + session.install('-e', '.') + session.install('sphinx', 'alabaster', 'recommonmark') - shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + shutil.rmtree(os.path.join('docs', '_build'), ignore_errors=True) session.run( - "sphinx-build", - "-T", # show full traceback on exception - "-N", # no colors - "-b", - "html", - "-d", - os.path.join("docs", "_build", "doctrees", ""), - os.path.join("docs", ""), - os.path.join("docs", "_build", "html", ""), + 'sphinx-build', + + '-T', # show full traceback on exception + '-N', # no colors + '-b', 'html', + '-d', os.path.join('docs', '_build', 'doctrees', ''), + os.path.join('docs', ''), + os.path.join('docs', '_build', 'html', ''), ) diff --git a/tests/unit/gapic/aiplatform_v1/__init__.py b/tests/unit/gapic/aiplatform_v1/__init__.py index 42ffdf2bc4..6a73015364 100644 --- a/tests/unit/gapic/aiplatform_v1/__init__.py +++ b/tests/unit/gapic/aiplatform_v1/__init__.py @@ -1,3 +1,4 @@ + # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index 1597014605..118d0eefe5 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.dataset_service import ( - DatasetServiceAsyncClient, -) +from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceAsyncClient from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceClient from google.cloud.aiplatform_v1.services.dataset_service import pagers from google.cloud.aiplatform_v1.services.dataset_service import transports @@ -65,11 +63,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -80,52 +74,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + DatasetServiceClient, + DatasetServiceAsyncClient, +]) def test_dataset_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + DatasetServiceClient, + DatasetServiceAsyncClient, +]) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -135,7 +113,7 @@ def test_dataset_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_dataset_service_client_get_transport_class(): @@ -149,44 +127,29 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) -def test_dataset_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +def test_dataset_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -202,7 +165,7 @@ def test_dataset_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -218,7 +181,7 @@ def test_dataset_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -238,15 +201,13 @@ def test_dataset_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -259,52 +220,26 @@ def test_dataset_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -327,18 +262,10 @@ def test_dataset_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -359,14 +286,9 @@ def test_dataset_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -380,23 +302,16 @@ def test_dataset_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -409,24 +324,16 @@ def test_dataset_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -441,12 +348,10 @@ def test_dataset_service_client_client_options_credentials_file( def test_dataset_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -459,11 +364,10 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset( - transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest -): +def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -471,9 +375,11 @@ def test_create_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_dataset(request) @@ -495,24 +401,25 @@ def test_create_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: client.create_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.CreateDatasetRequest() - @pytest.mark.asyncio -async def test_create_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest -): +async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.CreateDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -520,10 +427,12 @@ async def test_create_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_dataset(request) @@ -544,16 +453,20 @@ async def test_create_dataset_async_from_dict(): def test_create_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_dataset(request) @@ -564,23 +477,28 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_dataset(request) @@ -591,21 +509,29 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -613,40 +539,47 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') def test_create_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -654,30 +587,31 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) -def test_get_dataset( - transport: str = "grpc", request_type=dataset_service.GetDatasetRequest -): +def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -685,13 +619,19 @@ def test_get_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + ) response = client.get_dataset(request) @@ -706,13 +646,13 @@ def test_get_dataset( assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_dataset_from_dict(): @@ -723,24 +663,25 @@ def test_get_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: client.get_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetDatasetRequest() - @pytest.mark.asyncio -async def test_get_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest -): +async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -748,16 +689,16 @@ async def test_get_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) response = await client.get_dataset(request) @@ -770,13 +711,13 @@ async def test_get_dataset_async( # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -785,15 +726,19 @@ async def test_get_dataset_async_from_dict(): def test_get_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -805,20 +750,27 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -830,79 +782,99 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset(name="name_value",) + client.get_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", + dataset_service.GetDatasetRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset(name="name_value",) + response = await client.get_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", + dataset_service.GetDatasetRequest(), + name='name_value', ) -def test_update_dataset( - transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest -): +def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -910,13 +882,19 @@ def test_update_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + ) response = client.update_dataset(request) @@ -931,13 +909,13 @@ def test_update_dataset( assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_dataset_from_dict(): @@ -948,24 +926,25 @@ def test_update_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: client.update_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.UpdateDatasetRequest() - @pytest.mark.asyncio -async def test_update_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest -): +async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.UpdateDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -973,16 +952,16 @@ async def test_update_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) response = await client.update_dataset(request) @@ -995,13 +974,13 @@ async def test_update_dataset_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -1010,15 +989,19 @@ async def test_update_dataset_async_from_dict(): def test_update_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" + request.dataset.name = 'dataset.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -1030,22 +1013,27 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" + request.dataset.name = 'dataset.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -1057,24 +1045,29 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] def test_update_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1082,30 +1075,36 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1113,8 +1112,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1122,30 +1121,31 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_list_datasets( - transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest -): +def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1153,10 +1153,13 @@ def test_list_datasets( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_datasets(request) @@ -1171,7 +1174,7 @@ def test_list_datasets( assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_datasets_from_dict(): @@ -1182,24 +1185,25 @@ def test_list_datasets_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: client.list_datasets() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDatasetsRequest() - @pytest.mark.asyncio -async def test_list_datasets_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest -): +async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDatasetsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1207,13 +1211,13 @@ async def test_list_datasets_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_datasets(request) @@ -1226,7 +1230,7 @@ async def test_list_datasets_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1235,15 +1239,19 @@ async def test_list_datasets_async_from_dict(): def test_list_datasets_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1255,23 +1263,28 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) await client.list_datasets(request) @@ -1282,100 +1295,138 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_datasets_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets(parent="parent_value",) + client.list_datasets( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_datasets_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", + dataset_service.ListDatasetsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets(parent="parent_value",) + response = await client.list_datasets( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", + dataset_service.ListDatasetsRequest(), + parent='parent_value', ) def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_datasets(request={}) @@ -1383,102 +1434,147 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - + assert all(isinstance(i, dataset.Dataset) + for i in results) def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_datasets), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) for i in responses) - + assert all(isinstance(i, dataset.Dataset) + for i in responses) @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_datasets), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_dataset( - transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest -): +def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1486,9 +1582,11 @@ def test_delete_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_dataset(request) @@ -1510,24 +1608,25 @@ def test_delete_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: client.delete_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.DeleteDatasetRequest() - @pytest.mark.asyncio -async def test_delete_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest -): +async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.DeleteDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1535,10 +1634,12 @@ async def test_delete_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_dataset(request) @@ -1559,16 +1660,20 @@ async def test_delete_dataset_async_from_dict(): def test_delete_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_dataset(request) @@ -1579,23 +1684,28 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_dataset(request) @@ -1606,81 +1716,101 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset(name="name_value",) + client.delete_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", + dataset_service.DeleteDatasetRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset(name="name_value",) + response = await client.delete_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", + dataset_service.DeleteDatasetRequest(), + name='name_value', ) -def test_import_data( - transport: str = "grpc", request_type=dataset_service.ImportDataRequest -): +def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1688,9 +1818,11 @@ def test_import_data( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.import_data(request) @@ -1712,24 +1844,25 @@ def test_import_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: client.import_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ImportDataRequest() - @pytest.mark.asyncio -async def test_import_data_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest -): +async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ImportDataRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1737,10 +1870,12 @@ async def test_import_data_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.import_data(request) @@ -1761,16 +1896,20 @@ async def test_import_data_async_from_dict(): def test_import_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.import_data(request) @@ -1781,23 +1920,28 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.import_data(request) @@ -1808,24 +1952,29 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_import_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) # Establish that the underlying call was made with the expected @@ -1833,47 +1982,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] def test_import_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) # Establish that the underlying call was made with the expected @@ -1881,34 +2030,31 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) -def test_export_data( - transport: str = "grpc", request_type=dataset_service.ExportDataRequest -): +def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1916,9 +2062,11 @@ def test_export_data( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.export_data(request) @@ -1940,24 +2088,25 @@ def test_export_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: client.export_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ExportDataRequest() - @pytest.mark.asyncio -async def test_export_data_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest -): +async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ExportDataRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1965,10 +2114,12 @@ async def test_export_data_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.export_data(request) @@ -1989,16 +2140,20 @@ async def test_export_data_async_from_dict(): def test_export_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.export_data(request) @@ -2009,23 +2164,28 @@ def test_export_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.export_data(request) @@ -2036,26 +2196,29 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_export_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) # Establish that the underlying call was made with the expected @@ -2063,53 +2226,47 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) + assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) def test_export_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) # Establish that the underlying call was made with the expected @@ -2117,38 +2274,31 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) + assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) -def test_list_data_items( - transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest -): +def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2156,10 +2306,13 @@ def test_list_data_items( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_data_items(request) @@ -2174,7 +2327,7 @@ def test_list_data_items( assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_data_items_from_dict(): @@ -2185,24 +2338,25 @@ def test_list_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: client.list_data_items() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDataItemsRequest() - @pytest.mark.asyncio -async def test_list_data_items_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest -): +async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDataItemsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2210,13 +2364,13 @@ async def test_list_data_items_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_data_items(request) @@ -2229,7 +2383,7 @@ async def test_list_data_items_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2238,15 +2392,19 @@ async def test_list_data_items_async_from_dict(): def test_list_data_items_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2258,23 +2416,28 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) await client.list_data_items(request) @@ -2285,81 +2448,104 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_data_items_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items(parent="parent_value",) + client.list_data_items( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_data_items_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", + dataset_service.ListDataItemsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items(parent="parent_value",) + response = await client.list_data_items( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", + dataset_service.ListDataItemsRequest(), + parent='parent_value', ) def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2368,23 +2554,32 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_data_items(request={}) @@ -2392,14 +2587,18 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) - + assert all(isinstance(i, data_item.DataItem) + for i in results) def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2408,32 +2607,40 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_data_items), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2442,37 +2649,46 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) for i in responses) - + assert all(isinstance(i, data_item.DataItem) + for i in responses) @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_data_items), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2481,31 +2697,37 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec( - transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest -): +def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2514,11 +2736,16 @@ def test_get_annotation_spec( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + ) response = client.get_annotation_spec(request) @@ -2533,11 +2760,11 @@ def test_get_annotation_spec( assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_annotation_spec_from_dict(): @@ -2548,27 +2775,25 @@ def test_get_annotation_spec_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: client.get_annotation_spec() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetAnnotationSpecRequest() - @pytest.mark.asyncio -async def test_get_annotation_spec_async( - transport: str = "grpc_asyncio", - request_type=dataset_service.GetAnnotationSpecRequest, -): +async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetAnnotationSpecRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2577,14 +2802,14 @@ async def test_get_annotation_spec_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( + name='name_value', + display_name='display_name_value', + etag='etag_value', + )) response = await client.get_annotation_spec(request) @@ -2597,11 +2822,11 @@ async def test_get_annotation_spec_async( # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -2610,17 +2835,19 @@ async def test_get_annotation_spec_async_from_dict(): def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2632,25 +2859,28 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) + type(client.transport.get_annotation_spec), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) await client.get_annotation_spec(request) @@ -2661,85 +2891,99 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec(name="name_value",) + client.get_annotation_spec( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", + dataset_service.GetAnnotationSpecRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec(name="name_value",) + response = await client.get_annotation_spec( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", + dataset_service.GetAnnotationSpecRequest(), + name='name_value', ) -def test_list_annotations( - transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest -): +def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2747,10 +2991,13 @@ def test_list_annotations( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_annotations(request) @@ -2765,7 +3012,7 @@ def test_list_annotations( assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_annotations_from_dict(): @@ -2776,24 +3023,25 @@ def test_list_annotations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: client.list_annotations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListAnnotationsRequest() - @pytest.mark.asyncio -async def test_list_annotations_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest -): +async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListAnnotationsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2801,13 +3049,13 @@ async def test_list_annotations_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_annotations(request) @@ -2820,7 +3068,7 @@ async def test_list_annotations_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2829,15 +3077,19 @@ async def test_list_annotations_async_from_dict(): def test_list_annotations_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -2849,23 +3101,28 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) await client.list_annotations(request) @@ -2876,81 +3133,104 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_annotations_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations(parent="parent_value",) + client.list_annotations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_annotations_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", + dataset_service.ListAnnotationsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations(parent="parent_value",) + response = await client.list_annotations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", + dataset_service.ListAnnotationsRequest(), + parent='parent_value', ) def test_list_annotations_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2959,23 +3239,32 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_annotations(request={}) @@ -2983,14 +3272,18 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) for i in results) - + assert all(isinstance(i, annotation.Annotation) + for i in results) def test_list_annotations_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2999,32 +3292,40 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_annotations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3033,37 +3334,46 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) for i in responses) - + assert all(isinstance(i, annotation.Annotation) + for i in responses) @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_annotations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3072,23 +3382,30 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token @@ -3099,7 +3416,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3118,7 +3436,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -3146,16 +3465,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3163,8 +3479,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DatasetServiceGrpcTransport, + ) def test_dataset_service_base_transport_error(): @@ -3172,15 +3493,13 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_dataset_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3189,17 +3508,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_dataset", - "get_dataset", - "update_dataset", - "list_datasets", - "delete_dataset", - "import_data", - "export_data", - "list_data_items", - "get_annotation_spec", - "list_annotations", - ) + 'create_dataset', + 'get_dataset', + 'update_dataset', + 'list_datasets', + 'delete_dataset', + 'import_data', + 'export_data', + 'list_data_items', + 'get_annotation_spec', + 'list_annotations', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3212,28 +3531,23 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -3242,11 +3556,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -3254,25 +3568,19 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +def test_dataset_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3281,13 +3589,15 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_cl transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3302,40 +3612,38 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_cl with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_dataset_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3343,11 +3651,12 @@ def test_dataset_service_grpc_transport_channel(): def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3356,22 +3665,12 @@ def test_dataset_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3380,7 +3679,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3396,7 +3695,9 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3410,23 +3711,17 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +def test_dataset_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3443,7 +3738,9 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3456,12 +3753,16 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3469,12 +3770,16 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3487,26 +3792,19 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( - project=project, - location=location, - dataset=dataset, - data_item=data_item, - annotation=annotation, - ) - actual = DatasetServiceClient.annotation_path( - project, location, dataset, data_item, annotation - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", + } path = DatasetServiceClient.annotation_path(**expected) @@ -3514,31 +3812,24 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual - def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( - project=project, - location=location, - dataset=dataset, - annotation_spec=annotation_spec, - ) - actual = DatasetServiceClient.annotation_spec_path( - project, location, dataset, annotation_spec - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + "annotation_spec": "nudibranch", + } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3546,26 +3837,24 @@ def test_parse_annotation_spec_path(): actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual - def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( - project=project, location=location, dataset=dataset, data_item=data_item, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", + } path = DatasetServiceClient.data_item_path(**expected) @@ -3573,24 +3862,22 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual - def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + } path = DatasetServiceClient.dataset_path(**expected) @@ -3598,20 +3885,18 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nautilus", + } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3619,18 +3904,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "abalone", + } path = DatasetServiceClient.common_folder_path(**expected) @@ -3638,18 +3923,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "clam", + } path = DatasetServiceClient.common_organization_path(**expected) @@ -3657,18 +3942,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "octopus", + } path = DatasetServiceClient.common_project_path(**expected) @@ -3676,22 +3961,20 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "cuttlefish", + "location": "mussel", + } path = DatasetServiceClient.common_location_path(**expected) @@ -3703,19 +3986,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index bf351a3978..b2ae6bd168 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.endpoint_service import ( - EndpointServiceAsyncClient, -) +from google.cloud.aiplatform_v1.services.endpoint_service import EndpointServiceAsyncClient from google.cloud.aiplatform_v1.services.endpoint_service import EndpointServiceClient from google.cloud.aiplatform_v1.services.endpoint_service import pagers from google.cloud.aiplatform_v1.services.endpoint_service import transports @@ -62,11 +60,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -77,52 +71,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + EndpointServiceClient, + EndpointServiceAsyncClient, +]) def test_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + EndpointServiceClient, + EndpointServiceAsyncClient, +]) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -132,7 +110,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_endpoint_service_client_get_transport_class(): @@ -146,44 +124,29 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) -def test_endpoint_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -199,7 +162,7 @@ def test_endpoint_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -215,7 +178,7 @@ def test_endpoint_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -235,15 +198,13 @@ def test_endpoint_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -256,62 +217,26 @@ def test_endpoint_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "true", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "false", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -334,18 +259,10 @@ def test_endpoint_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -366,14 +283,9 @@ def test_endpoint_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -387,23 +299,16 @@ def test_endpoint_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -416,24 +321,16 @@ def test_endpoint_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -448,12 +345,10 @@ def test_endpoint_service_client_client_options_credentials_file( def test_endpoint_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -466,11 +361,10 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint( - transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest -): +def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -478,9 +372,11 @@ def test_create_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_endpoint(request) @@ -502,24 +398,25 @@ def test_create_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: client.create_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.CreateEndpointRequest() - @pytest.mark.asyncio -async def test_create_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest -): +async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.CreateEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -527,10 +424,12 @@ async def test_create_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_endpoint(request) @@ -551,16 +450,20 @@ async def test_create_endpoint_async_from_dict(): def test_create_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_endpoint(request) @@ -571,23 +474,28 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_endpoint(request) @@ -598,21 +506,29 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -620,40 +536,47 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') def test_create_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -661,30 +584,31 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) -def test_get_endpoint( - transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest -): +def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -692,13 +616,19 @@ def test_get_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + ) response = client.get_endpoint(request) @@ -713,13 +643,13 @@ def test_get_endpoint( assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_endpoint_from_dict(): @@ -730,24 +660,25 @@ def test_get_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: client.get_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.GetEndpointRequest() - @pytest.mark.asyncio -async def test_get_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest -): +async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.GetEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -755,16 +686,16 @@ async def test_get_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) response = await client.get_endpoint(request) @@ -777,13 +708,13 @@ async def test_get_endpoint_async( # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -792,15 +723,19 @@ async def test_get_endpoint_async_from_dict(): def test_get_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -812,20 +747,27 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -837,79 +779,99 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint(name="name_value",) + client.get_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", + endpoint_service.GetEndpointRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_endpoint(name="name_value",) + response = await client.get_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", + endpoint_service.GetEndpointRequest(), + name='name_value', ) -def test_list_endpoints( - transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest -): +def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -917,10 +879,13 @@ def test_list_endpoints( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_endpoints(request) @@ -935,7 +900,7 @@ def test_list_endpoints( assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_endpoints_from_dict(): @@ -946,24 +911,25 @@ def test_list_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: client.list_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.ListEndpointsRequest() - @pytest.mark.asyncio -async def test_list_endpoints_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest -): +async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.ListEndpointsRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -971,13 +937,13 @@ async def test_list_endpoints_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_endpoints(request) @@ -990,7 +956,7 @@ async def test_list_endpoints_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -999,15 +965,19 @@ async def test_list_endpoints_async_from_dict(): def test_list_endpoints_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -1019,23 +989,28 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) await client.list_endpoints(request) @@ -1046,81 +1021,104 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_endpoints_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints(parent="parent_value",) + client.list_endpoints( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_endpoints_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", + endpoint_service.ListEndpointsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints(parent="parent_value",) + response = await client.list_endpoints( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", + endpoint_service.ListEndpointsRequest(), + parent='parent_value', ) def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1129,23 +1127,32 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_endpoints(request={}) @@ -1153,14 +1160,18 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in results) - + assert all(isinstance(i, endpoint.Endpoint) + for i in results) def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1169,32 +1180,40 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1203,37 +1222,46 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in responses) - + assert all(isinstance(i, endpoint.Endpoint) + for i in responses) @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1242,31 +1270,37 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_update_endpoint( - transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest -): +def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1274,13 +1308,19 @@ def test_update_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + ) response = client.update_endpoint(request) @@ -1295,13 +1335,13 @@ def test_update_endpoint( assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_endpoint_from_dict(): @@ -1312,24 +1352,25 @@ def test_update_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: client.update_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UpdateEndpointRequest() - @pytest.mark.asyncio -async def test_update_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest -): +async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UpdateEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1337,16 +1378,16 @@ async def test_update_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) response = await client.update_endpoint(request) @@ -1359,13 +1400,13 @@ async def test_update_endpoint_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -1374,15 +1415,19 @@ async def test_update_endpoint_async_from_dict(): def test_update_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" + request.endpoint.name = 'endpoint.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1394,25 +1439,28 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" + request.endpoint.name = 'endpoint.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) await client.update_endpoint(request) @@ -1423,24 +1471,29 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] def test_update_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1448,41 +1501,45 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1490,30 +1547,31 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_delete_endpoint( - transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest -): +def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1521,9 +1579,11 @@ def test_delete_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_endpoint(request) @@ -1545,24 +1605,25 @@ def test_delete_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: client.delete_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeleteEndpointRequest() - @pytest.mark.asyncio -async def test_delete_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest -): +async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeleteEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1570,10 +1631,12 @@ async def test_delete_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_endpoint(request) @@ -1594,16 +1657,20 @@ async def test_delete_endpoint_async_from_dict(): def test_delete_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_endpoint(request) @@ -1614,23 +1681,28 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_endpoint(request) @@ -1641,81 +1713,101 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_endpoint(name="name_value",) + client.delete_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", + endpoint_service.DeleteEndpointRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint(name="name_value",) + response = await client.delete_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", + endpoint_service.DeleteEndpointRequest(), + name='name_value', ) -def test_deploy_model( - transport: str = "grpc", request_type=endpoint_service.DeployModelRequest -): +def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1723,9 +1815,11 @@ def test_deploy_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.deploy_model(request) @@ -1747,24 +1841,25 @@ def test_deploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: client.deploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeployModelRequest() - @pytest.mark.asyncio -async def test_deploy_model_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest -): +async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeployModelRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1772,10 +1867,12 @@ async def test_deploy_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.deploy_model(request) @@ -1796,16 +1893,20 @@ async def test_deploy_model_async_from_dict(): def test_deploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.deploy_model(request) @@ -1816,23 +1917,28 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.deploy_model(request) @@ -1843,29 +1949,30 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] def test_deploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1873,63 +1980,51 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) + assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1937,45 +2032,34 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) + assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) -def test_undeploy_model( - transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest -): +def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1983,9 +2067,11 @@ def test_undeploy_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.undeploy_model(request) @@ -2007,24 +2093,25 @@ def test_undeploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: client.undeploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UndeployModelRequest() - @pytest.mark.asyncio -async def test_undeploy_model_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest -): +async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UndeployModelRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2032,10 +2119,12 @@ async def test_undeploy_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.undeploy_model(request) @@ -2056,16 +2145,20 @@ async def test_undeploy_model_async_from_dict(): def test_undeploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.undeploy_model(request) @@ -2076,23 +2169,28 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.undeploy_model(request) @@ -2103,23 +2201,30 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] def test_undeploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -2127,45 +2232,51 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].deployed_model_id == 'deployed_model_id_value' - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -2173,25 +2284,27 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].deployed_model_id == 'deployed_model_id_value' - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) @@ -2202,7 +2315,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2221,7 +2335,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -2249,16 +2364,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2266,8 +2378,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.EndpointServiceGrpcTransport, + ) def test_endpoint_service_base_transport_error(): @@ -2275,15 +2392,13 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2292,14 +2407,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_endpoint", - "get_endpoint", - "list_endpoints", - "update_endpoint", - "delete_endpoint", - "deploy_model", - "undeploy_model", - ) + 'create_endpoint', + 'get_endpoint', + 'list_endpoints', + 'update_endpoint', + 'delete_endpoint', + 'deploy_model', + 'undeploy_model', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2312,28 +2427,23 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2342,11 +2452,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -2354,25 +2464,19 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -2381,13 +2485,15 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_c transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2402,40 +2508,38 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_c with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_endpoint_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2443,11 +2547,12 @@ def test_endpoint_service_grpc_transport_channel(): def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2456,22 +2561,12 @@ def test_endpoint_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2480,7 +2575,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2496,7 +2591,9 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2510,23 +2607,17 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +def test_endpoint_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2543,7 +2634,9 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2556,12 +2649,16 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2569,12 +2666,16 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2585,18 +2686,17 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = EndpointServiceClient.endpoint_path(**expected) @@ -2604,24 +2704,22 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = EndpointServiceClient.model_path(**expected) @@ -2629,20 +2727,18 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2650,18 +2746,18 @@ def test_parse_common_billing_account_path(): actual = EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = EndpointServiceClient.common_folder_path(**expected) @@ -2669,18 +2765,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = EndpointServiceClient.common_organization_path(**expected) @@ -2688,18 +2784,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = EndpointServiceClient.common_project_path(**expected) @@ -2707,22 +2803,20 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = EndpointServiceClient.common_location_path(**expected) @@ -2734,19 +2828,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.EndpointServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.EndpointServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index 50d1339247..c6acd32ec8 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -41,9 +41,7 @@ from google.cloud.aiplatform_v1.services.job_service import transports from google.cloud.aiplatform_v1.types import accelerator_type from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -52,9 +50,7 @@ from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import env_var from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state @@ -81,11 +77,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -96,49 +88,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert ( - JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) + assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [JobServiceClient, JobServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + JobServiceClient, + JobServiceAsyncClient, +]) def test_job_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [JobServiceClient, JobServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + JobServiceClient, + JobServiceAsyncClient, +]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -148,7 +127,7 @@ def test_job_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_job_service_client_get_transport_class(): @@ -162,42 +141,29 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) -def test_job_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +def test_job_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -213,7 +179,7 @@ def test_job_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -229,7 +195,7 @@ def test_job_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -249,15 +215,13 @@ def test_job_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -270,50 +234,26 @@ def test_job_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -336,18 +276,10 @@ def test_job_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -368,14 +300,9 @@ def test_job_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -389,23 +316,16 @@ def test_job_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -418,24 +338,16 @@ def test_job_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -450,11 +362,11 @@ def test_job_service_client_client_options_credentials_file( def test_job_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + client = JobServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -466,11 +378,10 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job( - transport: str = "grpc", request_type=job_service.CreateCustomJobRequest -): +def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -479,13 +390,16 @@ def test_create_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_custom_job(request) @@ -500,9 +414,9 @@ def test_create_custom_job( assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -515,26 +429,25 @@ def test_create_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: client.create_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateCustomJobRequest() - @pytest.mark.asyncio -async def test_create_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest -): +async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -543,16 +456,14 @@ async def test_create_custom_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_custom_job(request) @@ -565,9 +476,9 @@ async def test_create_custom_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -578,17 +489,19 @@ async def test_create_custom_job_async_from_dict(): def test_create_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -600,25 +513,28 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) + type(client.transport.create_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) await client.create_custom_job(request) @@ -629,24 +545,29 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -654,43 +575,45 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') def test_create_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -698,30 +621,31 @@ async def test_create_custom_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') @pytest.mark.asyncio async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_custom_job( job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) -def test_get_custom_job( - transport: str = "grpc", request_type=job_service.GetCustomJobRequest -): +def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -729,12 +653,17 @@ def test_get_custom_job( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_custom_job(request) @@ -749,9 +678,9 @@ def test_get_custom_job( assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -764,24 +693,25 @@ def test_get_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: client.get_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetCustomJobRequest() - @pytest.mark.asyncio -async def test_get_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest -): +async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -789,15 +719,15 @@ async def test_get_custom_job_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_custom_job(request) @@ -810,9 +740,9 @@ async def test_get_custom_job_async( # Establish that the response is the type that we expect. assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -823,15 +753,19 @@ async def test_get_custom_job_async_from_dict(): def test_get_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: call.return_value = custom_job.CustomJob() client.get_custom_job(request) @@ -843,23 +777,28 @@ def test_get_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) await client.get_custom_job(request) @@ -870,81 +809,99 @@ async def test_get_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_custom_job(name="name_value",) + client.get_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", + job_service.GetCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_custom_job(name="name_value",) + response = await client.get_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", + job_service.GetCustomJobRequest(), + name='name_value', ) -def test_list_custom_jobs( - transport: str = "grpc", request_type=job_service.ListCustomJobsRequest -): +def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -952,10 +909,13 @@ def test_list_custom_jobs( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_custom_jobs(request) @@ -970,7 +930,7 @@ def test_list_custom_jobs( assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_custom_jobs_from_dict(): @@ -981,24 +941,25 @@ def test_list_custom_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: client.list_custom_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListCustomJobsRequest() - @pytest.mark.asyncio -async def test_list_custom_jobs_async( - transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest -): +async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListCustomJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1006,11 +967,13 @@ async def test_list_custom_jobs_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_custom_jobs(request) @@ -1023,7 +986,7 @@ async def test_list_custom_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListCustomJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1032,15 +995,19 @@ async def test_list_custom_jobs_async_from_dict(): def test_list_custom_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: call.return_value = job_service.ListCustomJobsResponse() client.list_custom_jobs(request) @@ -1052,23 +1019,28 @@ def test_list_custom_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) await client.list_custom_jobs(request) @@ -1079,81 +1051,104 @@ async def test_list_custom_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_custom_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_custom_jobs(parent="parent_value",) + client.list_custom_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_custom_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", + job_service.ListCustomJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_custom_jobs(parent="parent_value",) + response = await client.list_custom_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", + job_service.ListCustomJobsRequest(), + parent='parent_value', ) def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1162,21 +1157,32 @@ def test_list_custom_jobs_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[], + next_page_token='def', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_custom_jobs(request={}) @@ -1184,14 +1190,18 @@ def test_list_custom_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in results) - + assert all(isinstance(i, custom_job.CustomJob) + for i in results) def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1200,30 +1210,40 @@ def test_list_custom_jobs_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_custom_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1232,35 +1252,46 @@ async def test_list_custom_jobs_async_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[], + next_page_token='def', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in responses) - + assert all(isinstance(i, custom_job.CustomJob) + for i in responses) @pytest.mark.asyncio async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_custom_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1269,29 +1300,37 @@ async def test_list_custom_jobs_async_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_custom_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_custom_job( - transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest -): +def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1300,10 +1339,10 @@ def test_delete_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_custom_job(request) @@ -1325,26 +1364,25 @@ def test_delete_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: client.delete_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteCustomJobRequest() - @pytest.mark.asyncio -async def test_delete_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest -): +async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1353,11 +1391,11 @@ async def test_delete_custom_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_custom_job(request) @@ -1378,18 +1416,20 @@ async def test_delete_custom_job_async_from_dict(): def test_delete_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_custom_job(request) @@ -1400,25 +1440,28 @@ def test_delete_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_custom_job(request) @@ -1429,85 +1472,101 @@ async def test_delete_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_custom_job(name="name_value",) + client.delete_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", + job_service.DeleteCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_custom_job(name="name_value",) + response = await client.delete_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", + job_service.DeleteCustomJobRequest(), + name='name_value', ) -def test_cancel_custom_job( - transport: str = "grpc", request_type=job_service.CancelCustomJobRequest -): +def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1516,8 +1575,8 @@ def test_cancel_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1541,26 +1600,25 @@ def test_cancel_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: client.cancel_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelCustomJobRequest() - @pytest.mark.asyncio -async def test_cancel_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest -): +async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1569,8 +1627,8 @@ async def test_cancel_custom_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1592,17 +1650,19 @@ async def test_cancel_custom_job_async_from_dict(): def test_cancel_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: call.return_value = None client.cancel_custom_job(request) @@ -1614,22 +1674,27 @@ def test_cancel_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_custom_job(request) @@ -1641,83 +1706,99 @@ async def test_cancel_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_custom_job(name="name_value",) + client.cancel_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", + job_service.CancelCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_custom_job(name="name_value",) + response = await client.cancel_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", + job_service.CancelCustomJobRequest(), + name='name_value', ) -def test_create_data_labeling_job( - transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest -): +def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1726,19 +1807,28 @@ def test_create_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, - specialist_pools=["specialist_pools_value"], + + specialist_pools=['specialist_pools_value'], + ) response = client.create_data_labeling_job(request) @@ -1753,23 +1843,23 @@ def test_create_data_labeling_job( assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] def test_create_data_labeling_job_from_dict(): @@ -1780,27 +1870,25 @@ def test_create_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: client.create_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_create_data_labeling_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CreateDataLabelingJobRequest, -): +async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1809,22 +1897,20 @@ async def test_create_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) response = await client.create_data_labeling_job(request) @@ -1837,23 +1923,23 @@ async def test_create_data_labeling_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] @pytest.mark.asyncio @@ -1862,17 +1948,19 @@ async def test_create_data_labeling_job_async_from_dict(): def test_create_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: call.return_value = gca_data_labeling_job.DataLabelingJob() client.create_data_labeling_job(request) @@ -1884,25 +1972,28 @@ def test_create_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) + type(client.transport.create_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) await client.create_data_labeling_job(request) @@ -1913,24 +2004,29 @@ async def test_create_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1938,45 +2034,45 @@ def test_create_data_labeling_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1984,32 +2080,31 @@ async def test_create_data_labeling_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) -def test_get_data_labeling_job( - transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest -): +def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2018,19 +2113,28 @@ def test_get_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, - specialist_pools=["specialist_pools_value"], + + specialist_pools=['specialist_pools_value'], + ) response = client.get_data_labeling_job(request) @@ -2045,23 +2149,23 @@ def test_get_data_labeling_job( assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] def test_get_data_labeling_job_from_dict(): @@ -2072,26 +2176,25 @@ def test_get_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: client.get_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_get_data_labeling_job_async( - transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest -): +async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2100,22 +2203,20 @@ async def test_get_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) response = await client.get_data_labeling_job(request) @@ -2128,23 +2229,23 @@ async def test_get_data_labeling_job_async( # Establish that the response is the type that we expect. assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] @pytest.mark.asyncio @@ -2153,17 +2254,19 @@ async def test_get_data_labeling_job_async_from_dict(): def test_get_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -2175,25 +2278,28 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) + type(client.transport.get_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) await client.get_data_labeling_job(request) @@ -2204,85 +2310,99 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job(name="name_value",) + client.get_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", + job_service.GetDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job(name="name_value",) + response = await client.get_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", + job_service.GetDataLabelingJobRequest(), + name='name_value', ) -def test_list_data_labeling_jobs( - transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest -): +def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2291,11 +2411,12 @@ def test_list_data_labeling_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_data_labeling_jobs(request) @@ -2310,7 +2431,7 @@ def test_list_data_labeling_jobs( assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_data_labeling_jobs_from_dict(): @@ -2321,27 +2442,25 @@ def test_list_data_labeling_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: client.list_data_labeling_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListDataLabelingJobsRequest() - @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async( - transport: str = "grpc_asyncio", - request_type=job_service.ListDataLabelingJobsRequest, -): +async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListDataLabelingJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2350,14 +2469,12 @@ async def test_list_data_labeling_jobs_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_data_labeling_jobs(request) @@ -2370,7 +2487,7 @@ async def test_list_data_labeling_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2379,17 +2496,19 @@ async def test_list_data_labeling_jobs_async_from_dict(): def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2401,25 +2520,28 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse() - ) + type(client.transport.list_data_labeling_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) await client.list_data_labeling_jobs(request) @@ -2430,87 +2552,104 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_labeling_jobs(parent="parent_value",) + client.list_data_labeling_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", + job_service.ListDataLabelingJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs(parent="parent_value",) + response = await client.list_data_labeling_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", + job_service.ListDataLabelingJobsRequest(), + parent='parent_value', ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2519,14 +2658,17 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2539,7 +2681,9 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_data_labeling_jobs(request={}) @@ -2547,16 +2691,18 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) - + assert all(isinstance(i, data_labeling_job.DataLabelingJob) + for i in results) def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2565,14 +2711,17 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2583,20 +2732,19 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2605,14 +2753,17 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2623,25 +2774,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) - + assert all(isinstance(i, data_labeling_job.DataLabelingJob) + for i in responses) @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2650,14 +2801,17 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2670,15 +2824,14 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job( - transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest -): +def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2687,10 +2840,10 @@ def test_delete_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_data_labeling_job(request) @@ -2712,27 +2865,25 @@ def test_delete_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: client.delete_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_delete_data_labeling_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.DeleteDataLabelingJobRequest, -): +async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2741,11 +2892,11 @@ async def test_delete_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_data_labeling_job(request) @@ -2766,18 +2917,20 @@ async def test_delete_data_labeling_job_async_from_dict(): def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_data_labeling_job(request) @@ -2788,25 +2941,28 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_data_labeling_job(request) @@ -2817,85 +2973,101 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_data_labeling_job(name="name_value",) + client.delete_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", + job_service.DeleteDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job(name="name_value",) + response = await client.delete_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", + job_service.DeleteDataLabelingJobRequest(), + name='name_value', ) -def test_cancel_data_labeling_job( - transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest -): +def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2904,8 +3076,8 @@ def test_cancel_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2929,27 +3101,25 @@ def test_cancel_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: client.cancel_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CancelDataLabelingJobRequest, -): +async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2958,8 +3128,8 @@ async def test_cancel_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -2981,17 +3151,19 @@ async def test_cancel_data_labeling_job_async_from_dict(): def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -3003,22 +3175,27 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -3030,84 +3207,99 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_data_labeling_job(name="name_value",) + client.cancel_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", + job_service.CancelDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job(name="name_value",) + response = await client.cancel_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", + job_service.CancelDataLabelingJobRequest(), + name='name_value', ) -def test_create_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3116,16 +3308,22 @@ def test_create_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_hyperparameter_tuning_job(request) @@ -3140,9 +3338,9 @@ def test_create_hyperparameter_tuning_job( assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3161,27 +3359,25 @@ def test_create_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: client.create_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3190,19 +3386,17 @@ async def test_create_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_hyperparameter_tuning_job(request) @@ -3215,9 +3409,9 @@ async def test_create_hyperparameter_tuning_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3234,17 +3428,19 @@ async def test_create_hyperparameter_tuning_job_async_from_dict(): def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -3256,25 +3452,28 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob() - ) + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) await client.create_hyperparameter_tuning_job(request) @@ -3285,26 +3484,29 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -3312,51 +3514,45 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -3364,36 +3560,31 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) -def test_get_hyperparameter_tuning_job( - transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest -): +def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3402,16 +3593,22 @@ def test_get_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_hyperparameter_tuning_job(request) @@ -3426,9 +3623,9 @@ def test_get_hyperparameter_tuning_job( assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3447,27 +3644,25 @@ def test_get_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: client.get_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.GetHyperparameterTuningJobRequest, -): +async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3476,19 +3671,17 @@ async def test_get_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_hyperparameter_tuning_job(request) @@ -3501,9 +3694,9 @@ async def test_get_hyperparameter_tuning_job_async( # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3520,17 +3713,19 @@ async def test_get_hyperparameter_tuning_job_async_from_dict(): def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3542,25 +3737,28 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob() - ) + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) await client.get_hyperparameter_tuning_job(request) @@ -3571,86 +3769,99 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job(name="name_value",) + client.get_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", + job_service.GetHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job(name="name_value",) + response = await client.get_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", + job_service.GetHyperparameterTuningJobRequest(), + name='name_value', ) -def test_list_hyperparameter_tuning_jobs( - transport: str = "grpc", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3659,11 +3870,12 @@ def test_list_hyperparameter_tuning_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3678,7 +3890,7 @@ def test_list_hyperparameter_tuning_jobs( assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3689,27 +3901,25 @@ def test_list_hyperparameter_tuning_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: client.list_hyperparameter_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListHyperparameterTuningJobsRequest() - @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async( - transport: str = "grpc_asyncio", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListHyperparameterTuningJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3718,14 +3928,12 @@ async def test_list_hyperparameter_tuning_jobs_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3738,7 +3946,7 @@ async def test_list_hyperparameter_tuning_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -3747,17 +3955,19 @@ async def test_list_hyperparameter_tuning_jobs_async_from_dict(): def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3769,25 +3979,28 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse() - ) + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) await client.list_hyperparameter_tuning_jobs(request) @@ -3798,87 +4011,104 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_hyperparameter_tuning_jobs(parent="parent_value",) + client.list_hyperparameter_tuning_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", + job_service.ListHyperparameterTuningJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) + response = await client.list_hyperparameter_tuning_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", + job_service.ListHyperparameterTuningJobsRequest(), + parent='parent_value', ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3887,16 +4117,17 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3909,7 +4140,9 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -3917,19 +4150,18 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results - ) - + assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results) def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3938,16 +4170,17 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3958,20 +4191,19 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3980,16 +4212,17 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4000,28 +4233,25 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses - ) - + assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4030,16 +4260,17 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4050,20 +4281,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( - await client.list_hyperparameter_tuning_jobs(request={}) - ).pages: + async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4072,10 +4299,10 @@ def test_delete_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_hyperparameter_tuning_job(request) @@ -4097,27 +4324,25 @@ def test_delete_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: client.delete_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4126,11 +4351,11 @@ async def test_delete_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_hyperparameter_tuning_job(request) @@ -4151,18 +4376,20 @@ async def test_delete_hyperparameter_tuning_job_async_from_dict(): def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_hyperparameter_tuning_job(request) @@ -4173,25 +4400,28 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_hyperparameter_tuning_job(request) @@ -4202,86 +4432,101 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job(name="name_value",) + client.delete_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + job_service.DeleteHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job(name="name_value",) + response = await client.delete_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + job_service.DeleteHyperparameterTuningJobRequest(), + name='name_value', ) -def test_cancel_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4290,8 +4535,8 @@ def test_cancel_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -4315,27 +4560,25 @@ def test_cancel_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: client.cancel_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4344,8 +4587,8 @@ async def test_cancel_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -4367,17 +4610,19 @@ async def test_cancel_hyperparameter_tuning_job_async_from_dict(): def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -4389,22 +4634,27 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4416,83 +4666,99 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job(name="name_value",) + client.cancel_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + job_service.CancelHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job(name="name_value",) + response = await client.cancel_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + job_service.CancelHyperparameterTuningJobRequest(), + name='name_value', ) -def test_create_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest -): +def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4501,14 +4767,18 @@ def test_create_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", + name='name_value', + + display_name='display_name_value', + + model='model_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_batch_prediction_job(request) @@ -4523,11 +4793,11 @@ def test_create_batch_prediction_job( assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -4540,27 +4810,25 @@ def test_create_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: client.create_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateBatchPredictionJobRequest() - @pytest.mark.asyncio -async def test_create_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CreateBatchPredictionJobRequest, -): +async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateBatchPredictionJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4569,17 +4837,15 @@ async def test_create_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob( + name='name_value', + display_name='display_name_value', + model='model_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_batch_prediction_job(request) @@ -4592,11 +4858,11 @@ async def test_create_batch_prediction_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -4607,17 +4873,19 @@ async def test_create_batch_prediction_job_async_from_dict(): def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: call.return_value = gca_batch_prediction_job.BatchPredictionJob() client.create_batch_prediction_job(request) @@ -4629,25 +4897,28 @@ def test_create_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob() - ) + type(client.transport.create_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) await client.create_batch_prediction_job(request) @@ -4658,26 +4929,29 @@ async def test_create_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -4685,51 +4959,45 @@ def test_create_batch_prediction_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) + assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -4737,36 +5005,31 @@ async def test_create_batch_prediction_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) + assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) -def test_get_batch_prediction_job( - transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest -): +def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4775,14 +5038,18 @@ def test_get_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", + name='name_value', + + display_name='display_name_value', + + model='model_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_batch_prediction_job(request) @@ -4797,11 +5064,11 @@ def test_get_batch_prediction_job( assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -4814,27 +5081,25 @@ def test_get_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: client.get_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetBatchPredictionJobRequest() - @pytest.mark.asyncio -async def test_get_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.GetBatchPredictionJobRequest, -): +async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetBatchPredictionJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4843,17 +5108,15 @@ async def test_get_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( + name='name_value', + display_name='display_name_value', + model='model_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_batch_prediction_job(request) @@ -4866,11 +5129,11 @@ async def test_get_batch_prediction_job_async( # Establish that the response is the type that we expect. assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -4881,17 +5144,19 @@ async def test_get_batch_prediction_job_async_from_dict(): def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: call.return_value = batch_prediction_job.BatchPredictionJob() client.get_batch_prediction_job(request) @@ -4903,25 +5168,28 @@ def test_get_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob() - ) + type(client.transport.get_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) await client.get_batch_prediction_job(request) @@ -4932,85 +5200,99 @@ async def test_get_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_batch_prediction_job(name="name_value",) + client.get_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", + job_service.GetBatchPredictionJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_batch_prediction_job(name="name_value",) + response = await client.get_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", + job_service.GetBatchPredictionJobRequest(), + name='name_value', ) -def test_list_batch_prediction_jobs( - transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest -): +def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5019,11 +5301,12 @@ def test_list_batch_prediction_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_batch_prediction_jobs(request) @@ -5038,7 +5321,7 @@ def test_list_batch_prediction_jobs( assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_batch_prediction_jobs_from_dict(): @@ -5049,27 +5332,25 @@ def test_list_batch_prediction_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: client.list_batch_prediction_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListBatchPredictionJobsRequest() - @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async( - transport: str = "grpc_asyncio", - request_type=job_service.ListBatchPredictionJobsRequest, -): +async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListBatchPredictionJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5078,14 +5359,12 @@ async def test_list_batch_prediction_jobs_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_batch_prediction_jobs(request) @@ -5098,7 +5377,7 @@ async def test_list_batch_prediction_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -5107,17 +5386,19 @@ async def test_list_batch_prediction_jobs_async_from_dict(): def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: call.return_value = job_service.ListBatchPredictionJobsResponse() client.list_batch_prediction_jobs(request) @@ -5129,25 +5410,28 @@ def test_list_batch_prediction_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse() - ) + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) await client.list_batch_prediction_jobs(request) @@ -5158,87 +5442,104 @@ async def test_list_batch_prediction_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_batch_prediction_jobs(parent="parent_value",) + client.list_batch_prediction_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", + job_service.ListBatchPredictionJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs(parent="parent_value",) + response = await client.list_batch_prediction_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", + job_service.ListBatchPredictionJobsRequest(), + parent='parent_value', ) def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5247,14 +5548,17 @@ def test_list_batch_prediction_jobs_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5267,7 +5571,9 @@ def test_list_batch_prediction_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_batch_prediction_jobs(request={}) @@ -5275,18 +5581,18 @@ def test_list_batch_prediction_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results - ) - + assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) + for i in results) def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5295,14 +5601,17 @@ def test_list_batch_prediction_jobs_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5313,20 +5622,19 @@ def test_list_batch_prediction_jobs_pages(): RuntimeError, ) pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5335,14 +5643,17 @@ async def test_list_batch_prediction_jobs_async_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5353,27 +5664,25 @@ async def test_list_batch_prediction_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses - ) - + assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) + for i in responses) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5382,14 +5691,17 @@ async def test_list_batch_prediction_jobs_async_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5402,15 +5714,14 @@ async def test_list_batch_prediction_jobs_async_pages(): pages = [] async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job( - transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest -): +def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5419,10 +5730,10 @@ def test_delete_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.delete_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_batch_prediction_job(request) @@ -5444,27 +5755,25 @@ def test_delete_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.delete_batch_prediction_job), + '__call__') as call: client.delete_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteBatchPredictionJobRequest() - @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.DeleteBatchPredictionJobRequest, -): +async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteBatchPredictionJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5473,11 +5782,11 @@ async def test_delete_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.delete_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_batch_prediction_job(request) @@ -5498,18 +5807,20 @@ async def test_delete_batch_prediction_job_async_from_dict(): def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_batch_prediction_job(request) @@ -5520,25 +5831,28 @@ def test_delete_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_batch_prediction_job(request) @@ -5549,85 +5863,101 @@ async def test_delete_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.delete_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_batch_prediction_job(name="name_value",) + client.delete_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", + job_service.DeleteBatchPredictionJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.delete_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job(name="name_value",) + response = await client.delete_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", + job_service.DeleteBatchPredictionJobRequest(), + name='name_value', ) -def test_cancel_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest -): +def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5636,8 +5966,8 @@ def test_cancel_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -5661,27 +5991,25 @@ def test_cancel_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: client.cancel_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelBatchPredictionJobRequest() - @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CancelBatchPredictionJobRequest, -): +async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelBatchPredictionJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5690,8 +6018,8 @@ async def test_cancel_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -5713,17 +6041,19 @@ async def test_cancel_batch_prediction_job_async_from_dict(): def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: call.return_value = None client.cancel_batch_prediction_job(request) @@ -5735,22 +6065,27 @@ def test_cancel_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_batch_prediction_job(request) @@ -5762,75 +6097,92 @@ async def test_cancel_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_batch_prediction_job(name="name_value",) + client.cancel_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", + job_service.CancelBatchPredictionJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job(name="name_value",) + response = await client.cancel_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", + job_service.CancelBatchPredictionJobRequest(), + name='name_value', ) @@ -5841,7 +6193,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -5860,7 +6213,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -5888,13 +6242,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport,], -) +@pytest.mark.parametrize("transport_class", [ + transports.JobServiceGrpcTransport, + transports.JobServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -5902,8 +6256,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.JobServiceGrpcTransport,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.JobServiceGrpcTransport, + ) def test_job_service_base_transport_error(): @@ -5911,15 +6270,13 @@ def test_job_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_job_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -5928,27 +6285,27 @@ def test_job_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_custom_job", - "get_custom_job", - "list_custom_jobs", - "delete_custom_job", - "cancel_custom_job", - "create_data_labeling_job", - "get_data_labeling_job", - "list_data_labeling_jobs", - "delete_data_labeling_job", - "cancel_data_labeling_job", - "create_hyperparameter_tuning_job", - "get_hyperparameter_tuning_job", - "list_hyperparameter_tuning_jobs", - "delete_hyperparameter_tuning_job", - "cancel_hyperparameter_tuning_job", - "create_batch_prediction_job", - "get_batch_prediction_job", - "list_batch_prediction_jobs", - "delete_batch_prediction_job", - "cancel_batch_prediction_job", - ) + 'create_custom_job', + 'get_custom_job', + 'list_custom_jobs', + 'delete_custom_job', + 'cancel_custom_job', + 'create_data_labeling_job', + 'get_data_labeling_job', + 'list_data_labeling_jobs', + 'delete_data_labeling_job', + 'cancel_data_labeling_job', + 'create_hyperparameter_tuning_job', + 'get_hyperparameter_tuning_job', + 'list_hyperparameter_tuning_jobs', + 'delete_hyperparameter_tuning_job', + 'cancel_hyperparameter_tuning_job', + 'create_batch_prediction_job', + 'get_batch_prediction_job', + 'list_batch_prediction_jobs', + 'delete_batch_prediction_job', + 'cancel_batch_prediction_job', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -5961,28 +6318,23 @@ def test_job_service_base_transport(): def test_job_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_job_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport() @@ -5991,11 +6343,11 @@ def test_job_service_base_transport_with_adc(): def test_job_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) JobServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -6003,22 +6355,19 @@ def test_job_service_auth_adc(): def test_job_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -6027,13 +6376,15 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class) transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -6048,40 +6399,38 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class) with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_job_service_host_with_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_job_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6089,11 +6438,12 @@ def test_job_service_grpc_transport_channel(): def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6102,17 +6452,12 @@ def test_job_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -6121,7 +6466,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -6137,7 +6482,9 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6151,20 +6498,17 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -6181,7 +6525,9 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6194,12 +6540,16 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): def test_job_service_grpc_lro_client(): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6207,12 +6557,16 @@ def test_job_service_grpc_lro_client(): def test_job_service_grpc_lro_async_client(): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6223,20 +6577,17 @@ def test_batch_prediction_job_path(): location = "clam" batch_prediction_job = "whelk" - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, location=location, batch_prediction_job=batch_prediction_job, - ) - actual = JobServiceClient.batch_prediction_job_path( - project, location, batch_prediction_job - ) + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) assert expected == actual def test_parse_batch_prediction_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", + } path = JobServiceClient.batch_prediction_job_path(**expected) @@ -6244,24 +6595,22 @@ def test_parse_batch_prediction_job_path(): actual = JobServiceClient.parse_batch_prediction_job_path(path) assert expected == actual - def test_custom_job_path(): project = "cuttlefish" location = "mussel" custom_job = "winkle" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) actual = JobServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", + } path = JobServiceClient.custom_job_path(**expected) @@ -6269,26 +6618,22 @@ def test_parse_custom_job_path(): actual = JobServiceClient.parse_custom_job_path(path) assert expected == actual - def test_data_labeling_job_path(): project = "squid" location = "clam" data_labeling_job = "whelk" - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) - actual = JobServiceClient.data_labeling_job_path( - project, location, data_labeling_job - ) + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) assert expected == actual def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", + } path = JobServiceClient.data_labeling_job_path(**expected) @@ -6296,24 +6641,22 @@ def test_parse_data_labeling_job_path(): actual = JobServiceClient.parse_data_labeling_job_path(path) assert expected == actual - def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = JobServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } path = JobServiceClient.dataset_path(**expected) @@ -6321,28 +6664,22 @@ def test_parse_dataset_path(): actual = JobServiceClient.parse_dataset_path(path) assert expected == actual - def test_hyperparameter_tuning_job_path(): project = "squid" location = "clam" hyperparameter_tuning_job = "whelk" - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) - actual = JobServiceClient.hyperparameter_tuning_job_path( - project, location, hyperparameter_tuning_job - ) + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) assert expected == actual def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "hyperparameter_tuning_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", + } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -6350,24 +6687,22 @@ def test_parse_hyperparameter_tuning_job_path(): actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = JobServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = JobServiceClient.model_path(**expected) @@ -6375,26 +6710,24 @@ def test_parse_model_path(): actual = JobServiceClient.parse_model_path(path) assert expected == actual - def test_trial_path(): project = "squid" location = "clam" study = "whelk" trial = "octopus" - expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( - project=project, location=location, study=study, trial=trial, - ) + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) actual = JobServiceClient.trial_path(project, location, study, trial) assert expected == actual def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", + } path = JobServiceClient.trial_path(**expected) @@ -6402,20 +6735,18 @@ def test_parse_trial_path(): actual = JobServiceClient.parse_trial_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = JobServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nautilus", + } path = JobServiceClient.common_billing_account_path(**expected) @@ -6423,18 +6754,18 @@ def test_parse_common_billing_account_path(): actual = JobServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = JobServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "abalone", + } path = JobServiceClient.common_folder_path(**expected) @@ -6442,18 +6773,18 @@ def test_parse_common_folder_path(): actual = JobServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = JobServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "clam", + } path = JobServiceClient.common_organization_path(**expected) @@ -6461,18 +6792,18 @@ def test_parse_common_organization_path(): actual = JobServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = JobServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "octopus", + } path = JobServiceClient.common_project_path(**expected) @@ -6480,22 +6811,20 @@ def test_parse_common_project_path(): actual = JobServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = JobServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "cuttlefish", + "location": "mussel", + } path = JobServiceClient.common_location_path(**expected) @@ -6507,19 +6836,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: transport_class = JobServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 04bc7c392a..2f1c62f3ef 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.migration_service import ( - MigrationServiceAsyncClient, -) +from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceAsyncClient from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceClient from google.cloud.aiplatform_v1.services.migration_service import pagers from google.cloud.aiplatform_v1.services.migration_service import transports @@ -55,11 +53,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -70,53 +64,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert ( - MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) + assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + MigrationServiceClient, + MigrationServiceAsyncClient, +]) def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + MigrationServiceClient, + MigrationServiceAsyncClient, +]) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -126,7 +103,7 @@ def test_migration_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_migration_service_client_get_transport_class(): @@ -140,44 +117,29 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - MigrationServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceClient), -) -@mock.patch.object( - MigrationServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceAsyncClient), -) -def test_migration_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +def test_migration_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -193,7 +155,7 @@ def test_migration_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -209,7 +171,7 @@ def test_migration_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -229,15 +191,13 @@ def test_migration_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -250,62 +210,26 @@ def test_migration_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - MigrationServiceClient, - transports.MigrationServiceGrpcTransport, - "grpc", - "true", - ), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - MigrationServiceClient, - transports.MigrationServiceGrpcTransport, - "grpc", - "false", - ), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - MigrationServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceClient), -) -@mock.patch.object( - MigrationServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceAsyncClient), -) + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -328,18 +252,10 @@ def test_migration_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -360,14 +276,9 @@ def test_migration_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -381,23 +292,16 @@ def test_migration_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_migration_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -410,24 +314,16 @@ def test_migration_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_migration_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -442,12 +338,10 @@ def test_migration_service_client_client_options_credentials_file( def test_migration_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -460,12 +354,10 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources( - transport: str = "grpc", - request_type=migration_service.SearchMigratableResourcesRequest, -): +def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -474,11 +366,12 @@ def test_search_migratable_resources( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.search_migratable_resources(request) @@ -493,7 +386,7 @@ def test_search_migratable_resources( assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_search_migratable_resources_from_dict(): @@ -504,27 +397,25 @@ def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.SearchMigratableResourcesRequest() - @pytest.mark.asyncio -async def test_search_migratable_resources_async( - transport: str = "grpc_asyncio", - request_type=migration_service.SearchMigratableResourcesRequest, -): +async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -533,14 +424,12 @@ async def test_search_migratable_resources_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( + next_page_token='next_page_token_value', + )) response = await client.search_migratable_resources(request) @@ -553,7 +442,7 @@ async def test_search_migratable_resources_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -562,17 +451,19 @@ async def test_search_migratable_resources_async_from_dict(): def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -584,7 +475,10 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -596,15 +490,13 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse() - ) + type(client.transport.search_migratable_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) await client.search_migratable_resources(request) @@ -615,39 +507,49 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources(parent="parent_value",) + client.search_migratable_resources( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), parent="parent_value", + migration_service.SearchMigratableResourcesRequest(), + parent='parent_value', ) @@ -659,24 +561,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources(parent="parent_value",) + response = await client.search_migratable_resources( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio @@ -689,17 +591,20 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), parent="parent_value", + migration_service.SearchMigratableResourcesRequest(), + parent='parent_value', ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -708,14 +613,17 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -728,7 +636,9 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.search_migratable_resources(request={}) @@ -736,18 +646,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, migratable_resource.MigratableResource) for i in results - ) - + assert all(isinstance(i, migratable_resource.MigratableResource) + for i in results) def test_search_migratable_resources_pages(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -756,14 +666,17 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -774,20 +687,19 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.search_migratable_resources), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -796,14 +708,17 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -814,27 +729,25 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, migratable_resource.MigratableResource) for i in responses - ) - + assert all(isinstance(i, migratable_resource.MigratableResource) + for i in responses) @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.search_migratable_resources), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -843,14 +756,17 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -863,15 +779,14 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources( - transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest -): +def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -880,10 +795,10 @@ def test_batch_migrate_resources( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.batch_migrate_resources(request) @@ -905,27 +820,25 @@ def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.BatchMigrateResourcesRequest() - @pytest.mark.asyncio -async def test_batch_migrate_resources_async( - transport: str = "grpc_asyncio", - request_type=migration_service.BatchMigrateResourcesRequest, -): +async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -934,11 +847,11 @@ async def test_batch_migrate_resources_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.batch_migrate_resources(request) @@ -959,18 +872,20 @@ async def test_batch_migrate_resources_async_from_dict(): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.batch_migrate_resources(request) @@ -981,7 +896,10 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -993,15 +911,13 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.batch_migrate_resources(request) @@ -1012,30 +928,29 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) # Establish that the underlying call was made with the expected @@ -1043,33 +958,23 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].migrate_resource_requests == [ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ] + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) @@ -1081,25 +986,19 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_migrate_resources( - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) # Establish that the underlying call was made with the expected @@ -1107,15 +1006,9 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].migrate_resource_requests == [ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ] + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] @pytest.mark.asyncio @@ -1129,14 +1022,8 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) @@ -1147,7 +1034,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1166,7 +1054,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1194,16 +1083,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1211,8 +1097,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.MigrationServiceGrpcTransport, + ) def test_migration_service_base_transport_error(): @@ -1220,15 +1111,13 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1237,9 +1126,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "search_migratable_resources", - "batch_migrate_resources", - ) + 'search_migratable_resources', + 'batch_migrate_resources', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1252,28 +1141,23 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1282,11 +1166,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -1294,25 +1178,19 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) -def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +def test_migration_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1321,13 +1199,15 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_ transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1342,40 +1222,38 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_migration_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1383,11 +1261,12 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1396,22 +1275,12 @@ def test_migration_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1420,7 +1289,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1436,7 +1305,9 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1450,23 +1321,17 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) -def test_migration_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +def test_migration_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1483,7 +1348,9 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1496,12 +1363,16 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1509,12 +1380,16 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1525,20 +1400,17 @@ def test_annotated_dataset_path(): dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( - project=project, dataset=dataset, annotated_dataset=annotated_dataset, - ) - actual = MigrationServiceClient.annotated_dataset_path( - project, dataset, annotated_dataset - ) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", + } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1546,24 +1418,22 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual - def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1571,24 +1441,20 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_dataset_path(): project = "squid" - location = "clam" - dataset = "whelk" + dataset = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", + "project": "whelk", + "dataset": "octopus", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1596,22 +1462,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_dataset_path(): - project = "cuttlefish" - dataset = "mussel" + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, - ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1619,24 +1485,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", + "project": "clam", + "location": "whelk", + "model": "octopus", + } path = MigrationServiceClient.model_path(**expected) @@ -1644,24 +1508,22 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual - def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", + "project": "mussel", + "location": "winkle", + "model": "nautilus", + } path = MigrationServiceClient.model_path(**expected) @@ -1669,24 +1531,22 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual - def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format( - project=project, model=model, version=version, - ) + expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", + "project": "clam", + "model": "whelk", + "version": "octopus", + } path = MigrationServiceClient.version_path(**expected) @@ -1694,20 +1554,18 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "nudibranch", + } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1715,18 +1573,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "mussel", + } path = MigrationServiceClient.common_folder_path(**expected) @@ -1734,18 +1592,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "nautilus", + } path = MigrationServiceClient.common_organization_path(**expected) @@ -1753,18 +1611,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "abalone", + } path = MigrationServiceClient.common_project_path(**expected) @@ -1772,22 +1630,20 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "whelk", + "location": "octopus", + } path = MigrationServiceClient.common_location_path(**expected) @@ -1799,19 +1655,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.MigrationServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.MigrationServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index 15e4bad05d..0011bd1129 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -64,11 +64,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -79,49 +75,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) + assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [ModelServiceClient, ModelServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + ModelServiceClient, + ModelServiceAsyncClient, +]) def test_model_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [ModelServiceClient, ModelServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + ModelServiceClient, + ModelServiceAsyncClient, +]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -131,7 +114,7 @@ def test_model_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_model_service_client_get_transport_class(): @@ -145,42 +128,29 @@ def test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) -def test_model_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -196,7 +166,7 @@ def test_model_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -212,7 +182,7 @@ def test_model_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -232,15 +202,13 @@ def test_model_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -253,50 +221,26 @@ def test_model_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -319,18 +263,10 @@ def test_model_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -351,14 +287,9 @@ def test_model_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -372,23 +303,16 @@ def test_model_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -401,24 +325,16 @@ def test_model_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -433,11 +349,11 @@ def test_model_service_client_client_options_credentials_file( def test_model_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + client = ModelServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -449,11 +365,10 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model( - transport: str = "grpc", request_type=model_service.UploadModelRequest -): +def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -461,9 +376,11 @@ def test_upload_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.upload_model(request) @@ -485,24 +402,25 @@ def test_upload_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: client.upload_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UploadModelRequest() - @pytest.mark.asyncio -async def test_upload_model_async( - transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest -): +async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UploadModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -510,10 +428,12 @@ async def test_upload_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.upload_model(request) @@ -534,16 +454,20 @@ async def test_upload_model_async_from_dict(): def test_upload_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.upload_model(request) @@ -554,23 +478,28 @@ def test_upload_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.upload_model(request) @@ -581,21 +510,29 @@ async def test_upload_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_upload_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -603,40 +540,47 @@ def test_upload_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') def test_upload_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.upload_model( model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) @pytest.mark.asyncio async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -644,28 +588,31 @@ async def test_upload_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') @pytest.mark.asyncio async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.upload_model( model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) -def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): +def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -673,21 +620,31 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + ) response = client.get_model(request) @@ -702,31 +659,25 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR assert isinstance(response, model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_model_from_dict(): @@ -737,24 +688,25 @@ def test_get_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelRequest() - @pytest.mark.asyncio -async def test_get_model_async( - transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest -): +async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -762,28 +714,22 @@ async def test_get_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) response = await client.get_model(request) @@ -796,31 +742,25 @@ async def test_get_model_async( # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -829,15 +769,19 @@ async def test_get_model_async_from_dict(): def test_get_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: call.return_value = model.Model() client.get_model(request) @@ -849,20 +793,27 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -874,79 +825,99 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model(name="name_value",) + client.get_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), name="name_value", + model_service.GetModelRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model(name="name_value",) + response = await client.get_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), name="name_value", + model_service.GetModelRequest(), + name='name_value', ) -def test_list_models( - transport: str = "grpc", request_type=model_service.ListModelsRequest -): +def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -954,10 +925,13 @@ def test_list_models( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_models(request) @@ -972,7 +946,7 @@ def test_list_models( assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_models_from_dict(): @@ -983,24 +957,25 @@ def test_list_models_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelsRequest() - @pytest.mark.asyncio -async def test_list_models_async( - transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest -): +async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1008,11 +983,13 @@ async def test_list_models_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_models(request) @@ -1025,7 +1002,7 @@ async def test_list_models_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1034,15 +1011,19 @@ async def test_list_models_async_from_dict(): def test_list_models_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: call.return_value = model_service.ListModelsResponse() client.list_models(request) @@ -1054,23 +1035,28 @@ def test_list_models_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) await client.list_models(request) @@ -1081,98 +1067,138 @@ async def test_list_models_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_models_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_models(parent="parent_value",) + client.list_models( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_models_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_models( - model_service.ListModelsRequest(), parent="parent_value", + model_service.ListModelsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_models(parent="parent_value",) + response = await client.list_models( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_models( - model_service.ListModelsRequest(), parent="parent_value", + model_service.ListModelsRequest(), + parent='parent_value', ) def test_list_models_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_models(request={}) @@ -1180,96 +1206,147 @@ def test_list_models_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model.Model) for i in results) - + assert all(isinstance(i, model.Model) + for i in results) def test_list_models_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_models_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_models), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) for i in responses) - + assert all(isinstance(i, model.Model) + for i in responses) @pytest.mark.asyncio async def test_list_models_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_models), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_models(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_update_model( - transport: str = "grpc", request_type=model_service.UpdateModelRequest -): +def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1277,21 +1354,31 @@ def test_update_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + ) response = client.update_model(request) @@ -1306,31 +1393,25 @@ def test_update_model( assert isinstance(response, gca_model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_model_from_dict(): @@ -1341,24 +1422,25 @@ def test_update_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: client.update_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateModelRequest() - @pytest.mark.asyncio -async def test_update_model_async( - transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest -): +async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1366,28 +1448,22 @@ async def test_update_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) response = await client.update_model(request) @@ -1400,31 +1476,25 @@ async def test_update_model_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -1433,15 +1503,19 @@ async def test_update_model_async_from_dict(): def test_update_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" + request.model.name = 'model.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: call.return_value = gca_model.Model() client.update_model(request) @@ -1453,20 +1527,27 @@ def test_update_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'model.name=model.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" + request.model.name = 'model.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) await client.update_model(request) @@ -1478,22 +1559,29 @@ async def test_update_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'model.name=model.name/value', + ) in kw['metadata'] def test_update_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1501,30 +1589,36 @@ def test_update_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() @@ -1532,8 +1626,8 @@ async def test_update_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1541,30 +1635,31 @@ async def test_update_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_delete_model( - transport: str = "grpc", request_type=model_service.DeleteModelRequest -): +def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1572,9 +1667,11 @@ def test_delete_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_model(request) @@ -1596,24 +1693,25 @@ def test_delete_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: client.delete_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.DeleteModelRequest() - @pytest.mark.asyncio -async def test_delete_model_async( - transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest -): +async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1621,10 +1719,12 @@ async def test_delete_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_model(request) @@ -1645,16 +1745,20 @@ async def test_delete_model_async_from_dict(): def test_delete_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_model(request) @@ -1665,23 +1769,28 @@ def test_delete_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_model(request) @@ -1692,81 +1801,101 @@ async def test_delete_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model(name="name_value",) + client.delete_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_model( - model_service.DeleteModelRequest(), name="name_value", + model_service.DeleteModelRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_model(name="name_value",) + response = await client.delete_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model( - model_service.DeleteModelRequest(), name="name_value", + model_service.DeleteModelRequest(), + name='name_value', ) -def test_export_model( - transport: str = "grpc", request_type=model_service.ExportModelRequest -): +def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1774,9 +1903,11 @@ def test_export_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.export_model(request) @@ -1798,24 +1929,25 @@ def test_export_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: client.export_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ExportModelRequest() - @pytest.mark.asyncio -async def test_export_model_async( - transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest -): +async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=model_service.ExportModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1823,10 +1955,12 @@ async def test_export_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.export_model(request) @@ -1847,16 +1981,20 @@ async def test_export_model_async_from_dict(): def test_export_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.export_model(request) @@ -1867,23 +2005,28 @@ def test_export_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.export_model(request) @@ -1894,24 +2037,29 @@ async def test_export_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_export_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) # Establish that the underlying call was made with the expected @@ -1919,47 +2067,47 @@ def test_export_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') def test_export_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_model( model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) @pytest.mark.asyncio async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) # Establish that the underlying call was made with the expected @@ -1967,34 +2115,31 @@ async def test_export_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') @pytest.mark.asyncio async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_model( model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) -def test_get_model_evaluation( - transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest -): +def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2003,13 +2148,16 @@ def test_get_model_evaluation( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + + slice_dimensions=['slice_dimensions_value'], + ) response = client.get_model_evaluation(request) @@ -2024,11 +2172,11 @@ def test_get_model_evaluation( assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' - assert response.slice_dimensions == ["slice_dimensions_value"] + assert response.slice_dimensions == ['slice_dimensions_value'] def test_get_model_evaluation_from_dict(): @@ -2039,27 +2187,25 @@ def test_get_model_evaluation_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: client.get_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationRequest() - @pytest.mark.asyncio -async def test_get_model_evaluation_async( - transport: str = "grpc_asyncio", - request_type=model_service.GetModelEvaluationRequest, -): +async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2068,16 +2214,14 @@ async def test_get_model_evaluation_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation( + name='name_value', + metrics_schema_uri='metrics_schema_uri_value', + slice_dimensions=['slice_dimensions_value'], + )) response = await client.get_model_evaluation(request) @@ -2090,11 +2234,11 @@ async def test_get_model_evaluation_async( # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' - assert response.slice_dimensions == ["slice_dimensions_value"] + assert response.slice_dimensions == ['slice_dimensions_value'] @pytest.mark.asyncio @@ -2103,17 +2247,19 @@ async def test_get_model_evaluation_async_from_dict(): def test_get_model_evaluation_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: call.return_value = model_evaluation.ModelEvaluation() client.get_model_evaluation(request) @@ -2125,25 +2271,28 @@ def test_get_model_evaluation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_evaluation_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation() - ) + type(client.transport.get_model_evaluation), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) await client.get_model_evaluation(request) @@ -2154,85 +2303,99 @@ async def test_get_model_evaluation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_evaluation_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation(name="name_value",) + client.get_model_evaluation( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", + model_service.GetModelEvaluationRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_evaluation_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation(name="name_value",) + response = await client.get_model_evaluation( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_evaluation_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", + model_service.GetModelEvaluationRequest(), + name='name_value', ) -def test_list_model_evaluations( - transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest -): +def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2241,11 +2404,12 @@ def test_list_model_evaluations( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_model_evaluations(request) @@ -2260,7 +2424,7 @@ def test_list_model_evaluations( assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_model_evaluations_from_dict(): @@ -2271,27 +2435,25 @@ def test_list_model_evaluations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: client.list_model_evaluations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationsRequest() - @pytest.mark.asyncio -async def test_list_model_evaluations_async( - transport: str = "grpc_asyncio", - request_type=model_service.ListModelEvaluationsRequest, -): +async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationsRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2300,14 +2462,12 @@ async def test_list_model_evaluations_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_model_evaluations(request) @@ -2320,7 +2480,7 @@ async def test_list_model_evaluations_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2329,17 +2489,19 @@ async def test_list_model_evaluations_async_from_dict(): def test_list_model_evaluations_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: call.return_value = model_service.ListModelEvaluationsResponse() client.list_model_evaluations(request) @@ -2351,25 +2513,28 @@ def test_list_model_evaluations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse() - ) + type(client.transport.list_model_evaluations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) await client.list_model_evaluations(request) @@ -2380,87 +2545,104 @@ async def test_list_model_evaluations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_model_evaluations_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluations(parent="parent_value",) + client.list_model_evaluations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", + model_service.ListModelEvaluationsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluations(parent="parent_value",) + response = await client.list_model_evaluations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", + model_service.ListModelEvaluationsRequest(), + parent='parent_value', ) def test_list_model_evaluations_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2469,14 +2651,17 @@ def test_list_model_evaluations_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2489,7 +2674,9 @@ def test_list_model_evaluations_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_model_evaluations(request={}) @@ -2497,16 +2684,18 @@ def test_list_model_evaluations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) - + assert all(isinstance(i, model_evaluation.ModelEvaluation) + for i in results) def test_list_model_evaluations_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2515,14 +2704,17 @@ def test_list_model_evaluations_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2533,20 +2725,19 @@ def test_list_model_evaluations_pages(): RuntimeError, ) pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2555,14 +2746,17 @@ async def test_list_model_evaluations_async_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2573,25 +2767,25 @@ async def test_list_model_evaluations_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) - + assert all(isinstance(i, model_evaluation.ModelEvaluation) + for i in responses) @pytest.mark.asyncio async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2600,14 +2794,17 @@ async def test_list_model_evaluations_async_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2620,15 +2817,14 @@ async def test_list_model_evaluations_async_pages(): pages = [] async for page_ in (await client.list_model_evaluations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice( - transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest -): +def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2637,11 +2833,14 @@ def test_get_model_evaluation_slice( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + ) response = client.get_model_evaluation_slice(request) @@ -2656,9 +2855,9 @@ def test_get_model_evaluation_slice( assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' def test_get_model_evaluation_slice_from_dict(): @@ -2669,27 +2868,25 @@ def test_get_model_evaluation_slice_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: client.get_model_evaluation_slice() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationSliceRequest() - @pytest.mark.asyncio -async def test_get_model_evaluation_slice_async( - transport: str = "grpc_asyncio", - request_type=model_service.GetModelEvaluationSliceRequest, -): +async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationSliceRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2698,14 +2895,13 @@ async def test_get_model_evaluation_slice_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( + name='name_value', + metrics_schema_uri='metrics_schema_uri_value', + )) response = await client.get_model_evaluation_slice(request) @@ -2718,9 +2914,9 @@ async def test_get_model_evaluation_slice_async( # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' @pytest.mark.asyncio @@ -2729,17 +2925,19 @@ async def test_get_model_evaluation_slice_async_from_dict(): def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: call.return_value = model_evaluation_slice.ModelEvaluationSlice() client.get_model_evaluation_slice(request) @@ -2751,25 +2949,28 @@ def test_get_model_evaluation_slice_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice() - ) + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) await client.get_model_evaluation_slice(request) @@ -2780,85 +2981,99 @@ async def test_get_model_evaluation_slice_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation_slice(name="name_value",) + client.get_model_evaluation_slice( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", + model_service.GetModelEvaluationSliceRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation_slice(name="name_value",) + response = await client.get_model_evaluation_slice( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", + model_service.GetModelEvaluationSliceRequest(), + name='name_value', ) -def test_list_model_evaluation_slices( - transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest -): +def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2867,11 +3082,12 @@ def test_list_model_evaluation_slices( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_model_evaluation_slices(request) @@ -2886,7 +3102,7 @@ def test_list_model_evaluation_slices( assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_model_evaluation_slices_from_dict(): @@ -2897,27 +3113,25 @@ def test_list_model_evaluation_slices_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: client.list_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationSlicesRequest() - @pytest.mark.asyncio -async def test_list_model_evaluation_slices_async( - transport: str = "grpc_asyncio", - request_type=model_service.ListModelEvaluationSlicesRequest, -): +async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationSlicesRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2926,14 +3140,12 @@ async def test_list_model_evaluation_slices_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse( + next_page_token='next_page_token_value', + )) response = await client.list_model_evaluation_slices(request) @@ -2946,7 +3158,7 @@ async def test_list_model_evaluation_slices_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2955,17 +3167,19 @@ async def test_list_model_evaluation_slices_async_from_dict(): def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: call.return_value = model_service.ListModelEvaluationSlicesResponse() client.list_model_evaluation_slices(request) @@ -2977,25 +3191,28 @@ def test_list_model_evaluation_slices_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_model_evaluation_slices_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse() - ) + type(client.transport.list_model_evaluation_slices), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) await client.list_model_evaluation_slices(request) @@ -3006,87 +3223,104 @@ async def test_list_model_evaluation_slices_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluation_slices(parent="parent_value",) + client.list_model_evaluation_slices( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", + model_service.ListModelEvaluationSlicesRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluation_slices(parent="parent_value",) + response = await client.list_model_evaluation_slices( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", + model_service.ListModelEvaluationSlicesRequest(), + parent='parent_value', ) def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3095,16 +3329,17 @@ def test_list_model_evaluation_slices_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3117,7 +3352,9 @@ def test_list_model_evaluation_slices_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_model_evaluation_slices(request={}) @@ -3125,18 +3362,18 @@ def test_list_model_evaluation_slices_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results - ) - + assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in results) def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3145,16 +3382,17 @@ def test_list_model_evaluation_slices_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3165,20 +3403,19 @@ def test_list_model_evaluation_slices_pages(): RuntimeError, ) pages = list(client.list_model_evaluation_slices(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3187,16 +3424,17 @@ async def test_list_model_evaluation_slices_async_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3207,28 +3445,25 @@ async def test_list_model_evaluation_slices_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluation_slices(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in responses - ) - + assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in responses) @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3237,16 +3472,17 @@ async def test_list_model_evaluation_slices_async_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3257,11 +3493,9 @@ async def test_list_model_evaluation_slices_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( - await client.list_model_evaluation_slices(request={}) - ).pages: + async for page_ in (await client.list_model_evaluation_slices(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token @@ -3272,7 +3506,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3291,7 +3526,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -3319,16 +3555,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3336,8 +3569,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ModelServiceGrpcTransport, + ) def test_model_service_base_transport_error(): @@ -3345,15 +3583,13 @@ def test_model_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_model_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3362,17 +3598,17 @@ def test_model_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "upload_model", - "get_model", - "list_models", - "update_model", - "delete_model", - "export_model", - "get_model_evaluation", - "list_model_evaluations", - "get_model_evaluation_slice", - "list_model_evaluation_slices", - ) + 'upload_model', + 'get_model', + 'list_models', + 'update_model', + 'delete_model', + 'export_model', + 'get_model_evaluation', + 'list_model_evaluations', + 'get_model_evaluation_slice', + 'list_model_evaluation_slices', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3385,28 +3621,23 @@ def test_model_service_base_transport(): def test_model_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_model_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport() @@ -3415,11 +3646,11 @@ def test_model_service_base_transport_with_adc(): def test_model_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) ModelServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -3427,22 +3658,19 @@ def test_model_service_auth_adc(): def test_model_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3451,13 +3679,15 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_clas transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3472,40 +3702,38 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_clas with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_model_service_host_no_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_model_service_host_with_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_model_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3513,11 +3741,12 @@ def test_model_service_grpc_transport_channel(): def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3526,17 +3755,12 @@ def test_model_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3545,7 +3769,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3561,7 +3785,9 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3575,20 +3801,17 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3605,7 +3828,9 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3618,12 +3843,16 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): def test_model_service_grpc_lro_client(): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3631,12 +3860,16 @@ def test_model_service_grpc_lro_client(): def test_model_service_grpc_lro_async_client(): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3647,18 +3880,17 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = ModelServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = ModelServiceClient.endpoint_path(**expected) @@ -3666,24 +3898,22 @@ def test_parse_endpoint_path(): actual = ModelServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = ModelServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = ModelServiceClient.model_path(**expected) @@ -3691,28 +3921,24 @@ def test_parse_model_path(): actual = ModelServiceClient.parse_model_path(path) assert expected == actual - def test_model_evaluation_path(): project = "squid" location = "clam" model = "whelk" evaluation = "octopus" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( - project=project, location=location, model=model, evaluation=evaluation, - ) - actual = ModelServiceClient.model_evaluation_path( - project, location, model, evaluation - ) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation) assert expected == actual def test_parse_model_evaluation_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "model": "cuttlefish", - "evaluation": "mussel", + "project": "oyster", + "location": "nudibranch", + "model": "cuttlefish", + "evaluation": "mussel", + } path = ModelServiceClient.model_evaluation_path(**expected) @@ -3720,7 +3946,6 @@ def test_parse_model_evaluation_path(): actual = ModelServiceClient.parse_model_evaluation_path(path) assert expected == actual - def test_model_evaluation_slice_path(): project = "winkle" location = "nautilus" @@ -3728,26 +3953,19 @@ def test_model_evaluation_slice_path(): evaluation = "abalone" slice = "squid" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( - project=project, - location=location, - model=model, - evaluation=evaluation, - slice=slice, - ) - actual = ModelServiceClient.model_evaluation_slice_path( - project, location, model, evaluation, slice - ) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice) assert expected == actual def test_parse_model_evaluation_slice_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - "evaluation": "oyster", - "slice": "nudibranch", + "project": "clam", + "location": "whelk", + "model": "octopus", + "evaluation": "oyster", + "slice": "nudibranch", + } path = ModelServiceClient.model_evaluation_slice_path(**expected) @@ -3755,26 +3973,22 @@ def test_parse_model_evaluation_slice_path(): actual = ModelServiceClient.parse_model_evaluation_slice_path(path) assert expected == actual - def test_training_pipeline_path(): project = "cuttlefish" location = "mussel" training_pipeline = "winkle" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) - actual = ModelServiceClient.training_pipeline_path( - project, location, training_pipeline - ) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "nautilus", - "location": "scallop", - "training_pipeline": "abalone", + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", + } path = ModelServiceClient.training_pipeline_path(**expected) @@ -3782,20 +3996,18 @@ def test_parse_training_pipeline_path(): actual = ModelServiceClient.parse_training_pipeline_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = ModelServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = ModelServiceClient.common_billing_account_path(**expected) @@ -3803,18 +4015,18 @@ def test_parse_common_billing_account_path(): actual = ModelServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = ModelServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = ModelServiceClient.common_folder_path(**expected) @@ -3822,18 +4034,18 @@ def test_parse_common_folder_path(): actual = ModelServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = ModelServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = ModelServiceClient.common_organization_path(**expected) @@ -3841,18 +4053,18 @@ def test_parse_common_organization_path(): actual = ModelServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = ModelServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = ModelServiceClient.common_project_path(**expected) @@ -3860,22 +4072,20 @@ def test_parse_common_project_path(): actual = ModelServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = ModelServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = ModelServiceClient.common_location_path(**expected) @@ -3887,19 +4097,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: transport_class = ModelServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 21e6d0d44f..de2ff38ef2 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.pipeline_service import ( - PipelineServiceAsyncClient, -) +from google.cloud.aiplatform_v1.services.pipeline_service import PipelineServiceAsyncClient from google.cloud.aiplatform_v1.services.pipeline_service import PipelineServiceClient from google.cloud.aiplatform_v1.services.pipeline_service import pagers from google.cloud.aiplatform_v1.services.pipeline_service import transports @@ -68,11 +66,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -83,52 +77,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PipelineServiceClient._get_default_mtls_endpoint(None) is None - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + PipelineServiceClient, + PipelineServiceAsyncClient, +]) def test_pipeline_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + PipelineServiceClient, + PipelineServiceAsyncClient, +]) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -138,7 +116,7 @@ def test_pipeline_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_pipeline_service_client_get_transport_class(): @@ -152,44 +130,29 @@ def test_pipeline_service_client_get_transport_class(): assert transport == transports.PipelineServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) -def test_pipeline_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +def test_pipeline_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -205,7 +168,7 @@ def test_pipeline_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -221,7 +184,7 @@ def test_pipeline_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -241,15 +204,13 @@ def test_pipeline_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -262,62 +223,26 @@ def test_pipeline_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "true", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "false", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_pipeline_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -340,18 +265,10 @@ def test_pipeline_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -372,14 +289,9 @@ def test_pipeline_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -393,23 +305,16 @@ def test_pipeline_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -422,24 +327,16 @@ def test_pipeline_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -454,12 +351,10 @@ def test_pipeline_service_client_client_options_credentials_file( def test_pipeline_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = PipelineServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -472,11 +367,10 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest -): +def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -485,14 +379,18 @@ def test_create_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) response = client.create_training_pipeline(request) @@ -507,11 +405,11 @@ def test_create_training_pipeline( assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -524,27 +422,25 @@ def test_create_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: client.create_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CreateTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_create_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.CreateTrainingPipelineRequest, -): +async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CreateTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -553,17 +449,15 @@ async def test_create_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) response = await client.create_training_pipeline(request) @@ -576,11 +470,11 @@ async def test_create_training_pipeline_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -591,17 +485,19 @@ async def test_create_training_pipeline_async_from_dict(): def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: call.return_value = gca_training_pipeline.TrainingPipeline() client.create_training_pipeline(request) @@ -613,25 +509,28 @@ def test_create_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) + type(client.transport.create_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) await client.create_training_pipeline(request) @@ -642,24 +541,29 @@ async def test_create_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -667,45 +571,45 @@ def test_create_training_pipeline_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -713,32 +617,31 @@ async def test_create_training_pipeline_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') @pytest.mark.asyncio async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) -def test_get_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest -): +def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -747,14 +650,18 @@ def test_get_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) response = client.get_training_pipeline(request) @@ -769,11 +676,11 @@ def test_get_training_pipeline( assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -786,27 +693,25 @@ def test_get_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: client.get_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.GetTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_get_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.GetTrainingPipelineRequest, -): +async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.GetTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -815,17 +720,15 @@ async def test_get_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) response = await client.get_training_pipeline(request) @@ -838,11 +741,11 @@ async def test_get_training_pipeline_async( # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -853,17 +756,19 @@ async def test_get_training_pipeline_async_from_dict(): def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: call.return_value = training_pipeline.TrainingPipeline() client.get_training_pipeline(request) @@ -875,25 +780,28 @@ def test_get_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) + type(client.transport.get_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) await client.get_training_pipeline(request) @@ -904,85 +812,99 @@ async def test_get_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_training_pipeline(name="name_value",) + client.get_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_training_pipeline(name="name_value",) + response = await client.get_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', ) -def test_list_training_pipelines( - transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest -): +def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -991,11 +913,12 @@ def test_list_training_pipelines( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_training_pipelines(request) @@ -1010,7 +933,7 @@ def test_list_training_pipelines( assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_training_pipelines_from_dict(): @@ -1021,27 +944,25 @@ def test_list_training_pipelines_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: client.list_training_pipelines() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.ListTrainingPipelinesRequest() - @pytest.mark.asyncio -async def test_list_training_pipelines_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.ListTrainingPipelinesRequest, -): +async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.ListTrainingPipelinesRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1050,14 +971,12 @@ async def test_list_training_pipelines_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( + next_page_token='next_page_token_value', + )) response = await client.list_training_pipelines(request) @@ -1070,7 +989,7 @@ async def test_list_training_pipelines_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1079,17 +998,19 @@ async def test_list_training_pipelines_async_from_dict(): def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: call.return_value = pipeline_service.ListTrainingPipelinesResponse() client.list_training_pipelines(request) @@ -1101,25 +1022,28 @@ def test_list_training_pipelines_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) + type(client.transport.list_training_pipelines), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) await client.list_training_pipelines(request) @@ -1130,87 +1054,104 @@ async def test_list_training_pipelines_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_training_pipelines_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_training_pipelines(parent="parent_value",) + client.list_training_pipelines( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_training_pipelines_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + pipeline_service.ListTrainingPipelinesRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_training_pipelines_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_training_pipelines(parent="parent_value",) + response = await client.list_training_pipelines( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + pipeline_service.ListTrainingPipelinesRequest(), + parent='parent_value', ) def test_list_training_pipelines_pager(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1219,14 +1160,17 @@ def test_list_training_pipelines_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1239,7 +1183,9 @@ def test_list_training_pipelines_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_training_pipelines(request={}) @@ -1247,16 +1193,18 @@ def test_list_training_pipelines_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) - + assert all(isinstance(i, training_pipeline.TrainingPipeline) + for i in results) def test_list_training_pipelines_pages(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1265,14 +1213,17 @@ def test_list_training_pipelines_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1283,20 +1234,19 @@ def test_list_training_pipelines_pages(): RuntimeError, ) pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_training_pipelines), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1305,14 +1255,17 @@ async def test_list_training_pipelines_async_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1323,25 +1276,25 @@ async def test_list_training_pipelines_async_pager(): RuntimeError, ) async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) - + assert all(isinstance(i, training_pipeline.TrainingPipeline) + for i in responses) @pytest.mark.asyncio async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_training_pipelines), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1350,14 +1303,17 @@ async def test_list_training_pipelines_async_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1370,15 +1326,14 @@ async def test_list_training_pipelines_async_pages(): pages = [] async for page_ in (await client.list_training_pipelines(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest -): +def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1387,10 +1342,10 @@ def test_delete_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_training_pipeline(request) @@ -1412,27 +1367,25 @@ def test_delete_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: client.delete_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_delete_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.DeleteTrainingPipelineRequest, -): +async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.DeleteTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1441,11 +1394,11 @@ async def test_delete_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_training_pipeline(request) @@ -1466,18 +1419,20 @@ async def test_delete_training_pipeline_async_from_dict(): def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_training_pipeline), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_training_pipeline(request) @@ -1488,25 +1443,28 @@ def test_delete_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_training_pipeline(request) @@ -1517,85 +1475,101 @@ async def test_delete_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_training_pipeline(name="name_value",) + client.delete_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + pipeline_service.DeleteTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_training_pipeline(name="name_value",) + response = await client.delete_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + pipeline_service.DeleteTrainingPipelineRequest(), + name='name_value', ) -def test_cancel_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest -): +def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1604,8 +1578,8 @@ def test_cancel_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1629,27 +1603,25 @@ def test_cancel_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: client.cancel_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CancelTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_cancel_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.CancelTrainingPipelineRequest, -): +async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CancelTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1658,8 +1630,8 @@ async def test_cancel_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1681,17 +1653,19 @@ async def test_cancel_training_pipeline_async_from_dict(): def test_cancel_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: call.return_value = None client.cancel_training_pipeline(request) @@ -1703,22 +1677,27 @@ def test_cancel_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_training_pipeline(request) @@ -1730,75 +1709,92 @@ async def test_cancel_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_training_pipeline(name="name_value",) + client.cancel_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + pipeline_service.CancelTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_training_pipeline(name="name_value",) + response = await client.cancel_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + pipeline_service.CancelTrainingPipelineRequest(), + name='name_value', ) @@ -1809,7 +1805,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1828,7 +1825,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1856,16 +1854,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1873,8 +1868,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.PipelineServiceGrpcTransport, + ) def test_pipeline_service_base_transport_error(): @@ -1882,15 +1882,13 @@ def test_pipeline_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_pipeline_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1899,12 +1897,12 @@ def test_pipeline_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_training_pipeline", - "get_training_pipeline", - "list_training_pipelines", - "delete_training_pipeline", - "cancel_training_pipeline", - ) + 'create_training_pipeline', + 'get_training_pipeline', + 'list_training_pipelines', + 'delete_training_pipeline', + 'cancel_training_pipeline', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1917,28 +1915,23 @@ def test_pipeline_service_base_transport(): def test_pipeline_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_pipeline_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport() @@ -1947,11 +1940,11 @@ def test_pipeline_service_base_transport_with_adc(): def test_pipeline_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PipelineServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -1959,25 +1952,19 @@ def test_pipeline_service_auth_adc(): def test_pipeline_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) -def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1986,13 +1973,15 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_c transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2007,40 +1996,38 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_c with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_pipeline_service_host_no_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_pipeline_service_host_with_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_pipeline_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2048,11 +2035,12 @@ def test_pipeline_service_grpc_transport_channel(): def test_pipeline_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2061,22 +2049,12 @@ def test_pipeline_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) def test_pipeline_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2085,7 +2063,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2101,7 +2079,9 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2115,23 +2095,17 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) -def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +def test_pipeline_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2148,7 +2122,9 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2161,12 +2137,16 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): def test_pipeline_service_grpc_lro_client(): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2174,12 +2154,16 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2190,18 +2174,17 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = PipelineServiceClient.endpoint_path(**expected) @@ -2209,24 +2192,22 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = PipelineServiceClient.model_path(**expected) @@ -2234,26 +2215,22 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual - def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) - actual = PipelineServiceClient.training_pipeline_path( - project, location, training_pipeline - ) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", + } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2261,20 +2238,18 @@ def test_parse_training_pipeline_path(): actual = PipelineServiceClient.parse_training_pipeline_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = PipelineServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "mussel", + } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2282,18 +2257,18 @@ def test_parse_common_billing_account_path(): actual = PipelineServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = PipelineServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "nautilus", + } path = PipelineServiceClient.common_folder_path(**expected) @@ -2301,18 +2276,18 @@ def test_parse_common_folder_path(): actual = PipelineServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = PipelineServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "abalone", + } path = PipelineServiceClient.common_organization_path(**expected) @@ -2320,18 +2295,18 @@ def test_parse_common_organization_path(): actual = PipelineServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = PipelineServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "clam", + } path = PipelineServiceClient.common_project_path(**expected) @@ -2339,22 +2314,20 @@ def test_parse_common_project_path(): actual = PipelineServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = PipelineServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "oyster", + "location": "nudibranch", + } path = PipelineServiceClient.common_location_path(**expected) @@ -2366,19 +2339,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.PipelineServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.PipelineServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: transport_class = PipelineServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index d5099832f0..4017a16cc3 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.specialist_pool_service import ( - SpecialistPoolServiceAsyncClient, -) -from google.cloud.aiplatform_v1.services.specialist_pool_service import ( - SpecialistPoolServiceClient, -) +from google.cloud.aiplatform_v1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient +from google.cloud.aiplatform_v1.services.specialist_pool_service import SpecialistPoolServiceClient from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1.services.specialist_pool_service import transports from google.cloud.aiplatform_v1.types import operation as gca_operation @@ -60,11 +56,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -75,53 +67,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + SpecialistPoolServiceClient, + SpecialistPoolServiceAsyncClient, +]) def test_specialist_pool_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + SpecialistPoolServiceClient, + SpecialistPoolServiceAsyncClient, +]) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -131,7 +106,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_specialist_pool_service_client_get_transport_class(): @@ -145,48 +120,29 @@ def test_specialist_pool_service_client_get_transport_class(): assert transport == transports.SpecialistPoolServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - SpecialistPoolServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceClient), -) -@mock.patch.object( - SpecialistPoolServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceAsyncClient), -) -def test_specialist_pool_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) +@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) +def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -202,7 +158,7 @@ def test_specialist_pool_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -218,7 +174,7 @@ def test_specialist_pool_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -238,15 +194,13 @@ def test_specialist_pool_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -259,62 +213,26 @@ def test_specialist_pool_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - "true", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - "false", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - SpecialistPoolServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceClient), -) -@mock.patch.object( - SpecialistPoolServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceAsyncClient), -) + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) +@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -337,18 +255,10 @@ def test_specialist_pool_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -369,14 +279,9 @@ def test_specialist_pool_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -390,27 +295,16 @@ def test_specialist_pool_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -423,28 +317,16 @@ def test_specialist_pool_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -459,12 +341,10 @@ def test_specialist_pool_service_client_client_options_credentials_file( def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -477,12 +357,10 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -491,10 +369,10 @@ def test_create_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_specialist_pool(request) @@ -516,27 +394,25 @@ def test_create_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: client.create_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_create_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.CreateSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -545,11 +421,11 @@ async def test_create_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_specialist_pool(request) @@ -577,13 +453,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_specialist_pool(request) @@ -594,7 +470,10 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -606,15 +485,13 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_specialist_pool(request) @@ -625,7 +502,10 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_specialist_pool_flattened(): @@ -635,16 +515,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -652,11 +532,9 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') def test_create_specialist_pool_flattened_error(): @@ -669,8 +547,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) @@ -682,19 +560,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -702,11 +580,9 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') @pytest.mark.asyncio @@ -720,17 +596,15 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) -def test_get_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -739,15 +613,20 @@ def test_get_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + + specialist_manager_emails=['specialist_manager_emails_value'], + + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + ) response = client.get_specialist_pool(request) @@ -762,15 +641,15 @@ def test_get_specialist_pool( assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] def test_get_specialist_pool_from_dict(): @@ -781,27 +660,25 @@ def test_get_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: client.get_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_get_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.GetSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -810,18 +687,16 @@ async def test_get_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( + name='name_value', + display_name='display_name_value', + specialist_managers_count=2662, + specialist_manager_emails=['specialist_manager_emails_value'], + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + )) response = await client.get_specialist_pool(request) @@ -834,15 +709,15 @@ async def test_get_specialist_pool_async( # Establish that the response is the type that we expect. assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] @pytest.mark.asyncio @@ -858,12 +733,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -875,7 +750,10 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -887,15 +765,13 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) + type(client.transport.get_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) await client.get_specialist_pool(request) @@ -906,7 +782,10 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_specialist_pool_flattened(): @@ -916,21 +795,23 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_specialist_pool(name="name_value",) + client.get_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_specialist_pool_flattened_error(): @@ -942,7 +823,8 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + specialist_pool_service.GetSpecialistPoolRequest(), + name='name_value', ) @@ -954,24 +836,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool(name="name_value",) + response = await client.get_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio @@ -984,16 +866,15 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + specialist_pool_service.GetSpecialistPoolRequest(), + name='name_value', ) -def test_list_specialist_pools( - transport: str = "grpc", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1002,11 +883,12 @@ def test_list_specialist_pools( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_specialist_pools(request) @@ -1021,7 +903,7 @@ def test_list_specialist_pools( assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_specialist_pools_from_dict(): @@ -1032,27 +914,25 @@ def test_list_specialist_pools_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: client.list_specialist_pools() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() - @pytest.mark.asyncio -async def test_list_specialist_pools_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.ListSpecialistPoolsRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1061,14 +941,12 @@ async def test_list_specialist_pools_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_specialist_pools(request) @@ -1081,7 +959,7 @@ async def test_list_specialist_pools_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1097,12 +975,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -1114,7 +992,10 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1126,15 +1007,13 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) + type(client.transport.list_specialist_pools), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) await client.list_specialist_pools(request) @@ -1145,7 +1024,10 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_specialist_pools_flattened(): @@ -1155,21 +1037,23 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_specialist_pools(parent="parent_value",) + client.list_specialist_pools( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_specialist_pools_flattened_error(): @@ -1181,7 +1065,8 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + specialist_pool_service.ListSpecialistPoolsRequest(), + parent='parent_value', ) @@ -1193,24 +1078,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools(parent="parent_value",) + response = await client.list_specialist_pools( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio @@ -1223,17 +1108,20 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + specialist_pool_service.ListSpecialistPoolsRequest(), + parent='parent_value', ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1242,14 +1130,17 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1262,7 +1153,9 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_specialist_pools(request={}) @@ -1270,16 +1163,18 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) - + assert all(isinstance(i, specialist_pool.SpecialistPool) + for i in results) def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1288,14 +1183,17 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1306,10 +1204,9 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1318,10 +1215,8 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_specialist_pools), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1330,14 +1225,17 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1348,14 +1246,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) - + assert all(isinstance(i, specialist_pool.SpecialistPool) + for i in responses) @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1365,10 +1263,8 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_specialist_pools), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1377,14 +1273,17 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1397,16 +1296,14 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1415,10 +1312,10 @@ def test_delete_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_specialist_pool(request) @@ -1440,27 +1337,25 @@ def test_delete_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: client.delete_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_delete_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1469,11 +1364,11 @@ async def test_delete_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_specialist_pool(request) @@ -1501,13 +1396,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_specialist_pool(request) @@ -1518,7 +1413,10 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1530,15 +1428,13 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_specialist_pool(request) @@ -1549,7 +1445,10 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_specialist_pool_flattened(): @@ -1559,21 +1458,23 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool(name="name_value",) + client.delete_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_specialist_pool_flattened_error(): @@ -1585,7 +1486,8 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + specialist_pool_service.DeleteSpecialistPoolRequest(), + name='name_value', ) @@ -1597,24 +1499,26 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool(name="name_value",) + response = await client.delete_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio @@ -1627,16 +1531,15 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + specialist_pool_service.DeleteSpecialistPoolRequest(), + name='name_value', ) -def test_update_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1645,10 +1548,10 @@ def test_update_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.update_specialist_pool(request) @@ -1670,27 +1573,25 @@ def test_update_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: client.update_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_update_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1699,11 +1600,11 @@ async def test_update_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.update_specialist_pool(request) @@ -1731,13 +1632,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" + request.specialist_pool.name = 'specialist_pool.name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.update_specialist_pool(request) @@ -1749,9 +1650,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1763,15 +1664,13 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" + request.specialist_pool.name = 'specialist_pool.name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.update_specialist_pool(request) @@ -1783,9 +1682,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] def test_update_specialist_pool_flattened(): @@ -1795,16 +1694,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1812,11 +1711,9 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_specialist_pool_flattened_error(): @@ -1829,8 +1726,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @@ -1842,19 +1739,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1862,11 +1759,9 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio @@ -1880,8 +1775,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @@ -1892,7 +1787,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1911,7 +1807,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1939,16 +1836,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1959,7 +1853,10 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) + assert isinstance( + client.transport, + transports.SpecialistPoolServiceGrpcTransport, + ) def test_specialist_pool_service_base_transport_error(): @@ -1967,15 +1864,13 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1984,12 +1879,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_specialist_pool", - "get_specialist_pool", - "list_specialist_pools", - "delete_specialist_pool", - "update_specialist_pool", - ) + 'create_specialist_pool', + 'get_specialist_pool', + 'list_specialist_pools', + 'delete_specialist_pool', + 'update_specialist_pool', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2002,28 +1897,23 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -2032,11 +1922,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -2044,26 +1934,18 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( - transport_class, + transport_class ): cred = credentials.AnonymousCredentials() @@ -2073,13 +1955,15 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2094,40 +1978,38 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2135,11 +2017,12 @@ def test_specialist_pool_service_grpc_transport_channel(): def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2148,22 +2031,12 @@ def test_specialist_pool_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2172,7 +2045,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2188,7 +2061,9 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2202,23 +2077,17 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) -def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +def test_specialist_pool_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2235,7 +2104,9 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2248,12 +2119,16 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2261,12 +2136,16 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2277,20 +2156,17 @@ def test_specialist_pool_path(): location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( - project=project, location=location, specialist_pool=specialist_pool, - ) - actual = SpecialistPoolServiceClient.specialist_pool_path( - project, location, specialist_pool - ) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", + } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -2298,20 +2174,18 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "mussel", + } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -2319,18 +2193,18 @@ def test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "nautilus", + } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2338,18 +2212,18 @@ def test_parse_common_folder_path(): actual = SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "abalone", + } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2357,18 +2231,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "clam", + } path = SpecialistPoolServiceClient.common_project_path(**expected) @@ -2376,22 +2250,20 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "oyster", + "location": "nudibranch", + } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2403,19 +2275,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/__init__.py b/tests/unit/gapic/aiplatform_v1beta1/__init__.py index 42ffdf2bc4..6a73015364 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/__init__.py +++ b/tests/unit/gapic/aiplatform_v1beta1/__init__.py @@ -1,3 +1,4 @@ + # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 6042fa6f42..eb48bd6ebb 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.services.dataset_service import transports from google.cloud.aiplatform_v1beta1.types import annotation @@ -67,11 +63,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -82,52 +74,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + DatasetServiceClient, + DatasetServiceAsyncClient, +]) def test_dataset_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + DatasetServiceClient, + DatasetServiceAsyncClient, +]) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -137,7 +113,7 @@ def test_dataset_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_dataset_service_client_get_transport_class(): @@ -151,44 +127,29 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) -def test_dataset_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +def test_dataset_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -204,7 +165,7 @@ def test_dataset_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -220,7 +181,7 @@ def test_dataset_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -240,15 +201,13 @@ def test_dataset_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -261,52 +220,26 @@ def test_dataset_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -329,18 +262,10 @@ def test_dataset_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -361,14 +286,9 @@ def test_dataset_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -382,23 +302,16 @@ def test_dataset_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -411,24 +324,16 @@ def test_dataset_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -443,12 +348,10 @@ def test_dataset_service_client_client_options_credentials_file( def test_dataset_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -461,11 +364,10 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset( - transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest -): +def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -473,9 +375,11 @@ def test_create_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_dataset(request) @@ -497,24 +401,25 @@ def test_create_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: client.create_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.CreateDatasetRequest() - @pytest.mark.asyncio -async def test_create_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest -): +async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.CreateDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -522,10 +427,12 @@ async def test_create_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_dataset(request) @@ -546,16 +453,20 @@ async def test_create_dataset_async_from_dict(): def test_create_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_dataset(request) @@ -566,23 +477,28 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_dataset(request) @@ -593,21 +509,29 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -615,40 +539,47 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') def test_create_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -656,30 +587,31 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) -def test_get_dataset( - transport: str = "grpc", request_type=dataset_service.GetDatasetRequest -): +def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -687,13 +619,19 @@ def test_get_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + ) response = client.get_dataset(request) @@ -708,13 +646,13 @@ def test_get_dataset( assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_dataset_from_dict(): @@ -725,24 +663,25 @@ def test_get_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: client.get_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetDatasetRequest() - @pytest.mark.asyncio -async def test_get_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest -): +async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -750,16 +689,16 @@ async def test_get_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) response = await client.get_dataset(request) @@ -772,13 +711,13 @@ async def test_get_dataset_async( # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -787,15 +726,19 @@ async def test_get_dataset_async_from_dict(): def test_get_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -807,20 +750,27 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -832,79 +782,99 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset(name="name_value",) + client.get_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", + dataset_service.GetDatasetRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset(name="name_value",) + response = await client.get_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", + dataset_service.GetDatasetRequest(), + name='name_value', ) -def test_update_dataset( - transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest -): +def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -912,13 +882,19 @@ def test_update_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + ) response = client.update_dataset(request) @@ -933,13 +909,13 @@ def test_update_dataset( assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_dataset_from_dict(): @@ -950,24 +926,25 @@ def test_update_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: client.update_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.UpdateDatasetRequest() - @pytest.mark.asyncio -async def test_update_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest -): +async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.UpdateDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -975,16 +952,16 @@ async def test_update_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) response = await client.update_dataset(request) @@ -997,13 +974,13 @@ async def test_update_dataset_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -1012,15 +989,19 @@ async def test_update_dataset_async_from_dict(): def test_update_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" + request.dataset.name = 'dataset.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -1032,22 +1013,27 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" + request.dataset.name = 'dataset.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -1059,24 +1045,29 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] def test_update_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1084,30 +1075,36 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1115,8 +1112,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1124,30 +1121,31 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_list_datasets( - transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest -): +def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1155,10 +1153,13 @@ def test_list_datasets( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_datasets(request) @@ -1173,7 +1174,7 @@ def test_list_datasets( assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_datasets_from_dict(): @@ -1184,24 +1185,25 @@ def test_list_datasets_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: client.list_datasets() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDatasetsRequest() - @pytest.mark.asyncio -async def test_list_datasets_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest -): +async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDatasetsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1209,13 +1211,13 @@ async def test_list_datasets_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_datasets(request) @@ -1228,7 +1230,7 @@ async def test_list_datasets_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1237,15 +1239,19 @@ async def test_list_datasets_async_from_dict(): def test_list_datasets_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1257,23 +1263,28 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) await client.list_datasets(request) @@ -1284,100 +1295,138 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_datasets_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets(parent="parent_value",) + client.list_datasets( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_datasets_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", + dataset_service.ListDatasetsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets(parent="parent_value",) + response = await client.list_datasets( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", + dataset_service.ListDatasetsRequest(), + parent='parent_value', ) def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_datasets(request={}) @@ -1385,102 +1434,147 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - + assert all(isinstance(i, dataset.Dataset) + for i in results) def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_datasets), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) for i in responses) - + assert all(isinstance(i, dataset.Dataset) + for i in responses) @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_datasets), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_dataset( - transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest -): +def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1488,9 +1582,11 @@ def test_delete_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_dataset(request) @@ -1512,24 +1608,25 @@ def test_delete_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: client.delete_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.DeleteDatasetRequest() - @pytest.mark.asyncio -async def test_delete_dataset_async( - transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest -): +async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.DeleteDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1537,10 +1634,12 @@ async def test_delete_dataset_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_dataset(request) @@ -1561,16 +1660,20 @@ async def test_delete_dataset_async_from_dict(): def test_delete_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_dataset(request) @@ -1581,23 +1684,28 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_dataset(request) @@ -1608,81 +1716,101 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset(name="name_value",) + client.delete_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", + dataset_service.DeleteDatasetRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset(name="name_value",) + response = await client.delete_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", + dataset_service.DeleteDatasetRequest(), + name='name_value', ) -def test_import_data( - transport: str = "grpc", request_type=dataset_service.ImportDataRequest -): +def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1690,9 +1818,11 @@ def test_import_data( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.import_data(request) @@ -1714,24 +1844,25 @@ def test_import_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: client.import_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ImportDataRequest() - @pytest.mark.asyncio -async def test_import_data_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest -): +async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ImportDataRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1739,10 +1870,12 @@ async def test_import_data_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.import_data(request) @@ -1763,16 +1896,20 @@ async def test_import_data_async_from_dict(): def test_import_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.import_data(request) @@ -1783,23 +1920,28 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.import_data(request) @@ -1810,24 +1952,29 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_import_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) # Establish that the underlying call was made with the expected @@ -1835,47 +1982,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] def test_import_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) # Establish that the underlying call was made with the expected @@ -1883,34 +2030,31 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) -def test_export_data( - transport: str = "grpc", request_type=dataset_service.ExportDataRequest -): +def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1918,9 +2062,11 @@ def test_export_data( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.export_data(request) @@ -1942,24 +2088,25 @@ def test_export_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: client.export_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ExportDataRequest() - @pytest.mark.asyncio -async def test_export_data_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest -): +async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ExportDataRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1967,10 +2114,12 @@ async def test_export_data_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.export_data(request) @@ -1991,16 +2140,20 @@ async def test_export_data_async_from_dict(): def test_export_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.export_data(request) @@ -2011,23 +2164,28 @@ def test_export_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.export_data(request) @@ -2038,26 +2196,29 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_export_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) # Establish that the underlying call was made with the expected @@ -2065,53 +2226,47 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) + assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) def test_export_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) # Establish that the underlying call was made with the expected @@ -2119,38 +2274,31 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) + assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) -def test_list_data_items( - transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest -): +def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2158,10 +2306,13 @@ def test_list_data_items( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_data_items(request) @@ -2176,7 +2327,7 @@ def test_list_data_items( assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_data_items_from_dict(): @@ -2187,24 +2338,25 @@ def test_list_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: client.list_data_items() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDataItemsRequest() - @pytest.mark.asyncio -async def test_list_data_items_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest -): +async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDataItemsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2212,13 +2364,13 @@ async def test_list_data_items_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_data_items(request) @@ -2231,7 +2383,7 @@ async def test_list_data_items_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2240,15 +2392,19 @@ async def test_list_data_items_async_from_dict(): def test_list_data_items_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2260,23 +2416,28 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) await client.list_data_items(request) @@ -2287,81 +2448,104 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_data_items_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items(parent="parent_value",) + client.list_data_items( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_data_items_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", + dataset_service.ListDataItemsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items(parent="parent_value",) + response = await client.list_data_items( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", + dataset_service.ListDataItemsRequest(), + parent='parent_value', ) def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2370,23 +2554,32 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_data_items(request={}) @@ -2394,14 +2587,18 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) - + assert all(isinstance(i, data_item.DataItem) + for i in results) def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2410,32 +2607,40 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_data_items), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2444,37 +2649,46 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) for i in responses) - + assert all(isinstance(i, data_item.DataItem) + for i in responses) @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_data_items), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2483,31 +2697,37 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec( - transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest -): +def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2516,11 +2736,16 @@ def test_get_annotation_spec( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + ) response = client.get_annotation_spec(request) @@ -2535,11 +2760,11 @@ def test_get_annotation_spec( assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_annotation_spec_from_dict(): @@ -2550,27 +2775,25 @@ def test_get_annotation_spec_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: client.get_annotation_spec() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetAnnotationSpecRequest() - @pytest.mark.asyncio -async def test_get_annotation_spec_async( - transport: str = "grpc_asyncio", - request_type=dataset_service.GetAnnotationSpecRequest, -): +async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetAnnotationSpecRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2579,14 +2802,14 @@ async def test_get_annotation_spec_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( + name='name_value', + display_name='display_name_value', + etag='etag_value', + )) response = await client.get_annotation_spec(request) @@ -2599,11 +2822,11 @@ async def test_get_annotation_spec_async( # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -2612,17 +2835,19 @@ async def test_get_annotation_spec_async_from_dict(): def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2634,25 +2859,28 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) + type(client.transport.get_annotation_spec), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) await client.get_annotation_spec(request) @@ -2663,85 +2891,99 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec(name="name_value",) + client.get_annotation_spec( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", + dataset_service.GetAnnotationSpecRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec(name="name_value",) + response = await client.get_annotation_spec( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", + dataset_service.GetAnnotationSpecRequest(), + name='name_value', ) -def test_list_annotations( - transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest -): +def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2749,10 +2991,13 @@ def test_list_annotations( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_annotations(request) @@ -2767,7 +3012,7 @@ def test_list_annotations( assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_annotations_from_dict(): @@ -2778,24 +3023,25 @@ def test_list_annotations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: client.list_annotations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListAnnotationsRequest() - @pytest.mark.asyncio -async def test_list_annotations_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest -): +async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListAnnotationsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2803,13 +3049,13 @@ async def test_list_annotations_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_annotations(request) @@ -2822,7 +3068,7 @@ async def test_list_annotations_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2831,15 +3077,19 @@ async def test_list_annotations_async_from_dict(): def test_list_annotations_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -2851,23 +3101,28 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) await client.list_annotations(request) @@ -2878,81 +3133,104 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_annotations_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations(parent="parent_value",) + client.list_annotations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_annotations_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", + dataset_service.ListAnnotationsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations(parent="parent_value",) + response = await client.list_annotations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", + dataset_service.ListAnnotationsRequest(), + parent='parent_value', ) def test_list_annotations_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2961,23 +3239,32 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_annotations(request={}) @@ -2985,14 +3272,18 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) for i in results) - + assert all(isinstance(i, annotation.Annotation) + for i in results) def test_list_annotations_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3001,32 +3292,40 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_annotations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3035,37 +3334,46 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) for i in responses) - + assert all(isinstance(i, annotation.Annotation) + for i in responses) @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_annotations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3074,23 +3382,30 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token @@ -3101,7 +3416,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3120,7 +3436,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -3148,16 +3465,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3165,8 +3479,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DatasetServiceGrpcTransport, + ) def test_dataset_service_base_transport_error(): @@ -3174,15 +3493,13 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_dataset_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3191,17 +3508,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_dataset", - "get_dataset", - "update_dataset", - "list_datasets", - "delete_dataset", - "import_data", - "export_data", - "list_data_items", - "get_annotation_spec", - "list_annotations", - ) + 'create_dataset', + 'get_dataset', + 'update_dataset', + 'list_datasets', + 'delete_dataset', + 'import_data', + 'export_data', + 'list_data_items', + 'get_annotation_spec', + 'list_annotations', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3214,28 +3531,23 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -3244,11 +3556,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -3256,25 +3568,19 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +def test_dataset_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3283,13 +3589,15 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_cl transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3304,40 +3612,38 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_cl with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_dataset_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3345,11 +3651,12 @@ def test_dataset_service_grpc_transport_channel(): def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3358,22 +3665,12 @@ def test_dataset_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3382,7 +3679,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3398,7 +3695,9 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3412,23 +3711,17 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +def test_dataset_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3445,7 +3738,9 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3458,12 +3753,16 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3471,12 +3770,16 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3489,26 +3792,19 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( - project=project, - location=location, - dataset=dataset, - data_item=data_item, - annotation=annotation, - ) - actual = DatasetServiceClient.annotation_path( - project, location, dataset, data_item, annotation - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", + } path = DatasetServiceClient.annotation_path(**expected) @@ -3516,31 +3812,24 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual - def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( - project=project, - location=location, - dataset=dataset, - annotation_spec=annotation_spec, - ) - actual = DatasetServiceClient.annotation_spec_path( - project, location, dataset, annotation_spec - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + "annotation_spec": "nudibranch", + } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3548,26 +3837,24 @@ def test_parse_annotation_spec_path(): actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual - def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( - project=project, location=location, dataset=dataset, data_item=data_item, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", + } path = DatasetServiceClient.data_item_path(**expected) @@ -3575,24 +3862,22 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual - def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + } path = DatasetServiceClient.dataset_path(**expected) @@ -3600,20 +3885,18 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nautilus", + } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3621,18 +3904,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "abalone", + } path = DatasetServiceClient.common_folder_path(**expected) @@ -3640,18 +3923,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "clam", + } path = DatasetServiceClient.common_organization_path(**expected) @@ -3659,18 +3942,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "octopus", + } path = DatasetServiceClient.common_project_path(**expected) @@ -3678,22 +3961,20 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "cuttlefish", + "location": "mussel", + } path = DatasetServiceClient.common_location_path(**expected) @@ -3705,19 +3986,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index bda98b26a5..47d80619c5 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceClient from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type @@ -67,11 +63,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -82,52 +74,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + EndpointServiceClient, + EndpointServiceAsyncClient, +]) def test_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + EndpointServiceClient, + EndpointServiceAsyncClient, +]) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -137,7 +113,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_endpoint_service_client_get_transport_class(): @@ -151,44 +127,29 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) -def test_endpoint_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -204,7 +165,7 @@ def test_endpoint_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -220,7 +181,7 @@ def test_endpoint_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -240,15 +201,13 @@ def test_endpoint_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -261,62 +220,26 @@ def test_endpoint_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "true", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "false", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -339,18 +262,10 @@ def test_endpoint_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -371,14 +286,9 @@ def test_endpoint_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -392,23 +302,16 @@ def test_endpoint_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -421,24 +324,16 @@ def test_endpoint_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -453,12 +348,10 @@ def test_endpoint_service_client_client_options_credentials_file( def test_endpoint_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -471,11 +364,10 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint( - transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest -): +def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -483,9 +375,11 @@ def test_create_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_endpoint(request) @@ -507,24 +401,25 @@ def test_create_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: client.create_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.CreateEndpointRequest() - @pytest.mark.asyncio -async def test_create_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest -): +async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.CreateEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -532,10 +427,12 @@ async def test_create_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_endpoint(request) @@ -556,16 +453,20 @@ async def test_create_endpoint_async_from_dict(): def test_create_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_endpoint(request) @@ -576,23 +477,28 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_endpoint(request) @@ -603,21 +509,29 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -625,40 +539,47 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') def test_create_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -666,30 +587,31 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) -def test_get_endpoint( - transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest -): +def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -697,13 +619,19 @@ def test_get_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + ) response = client.get_endpoint(request) @@ -718,13 +646,13 @@ def test_get_endpoint( assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_endpoint_from_dict(): @@ -735,24 +663,25 @@ def test_get_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: client.get_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.GetEndpointRequest() - @pytest.mark.asyncio -async def test_get_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest -): +async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.GetEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -760,16 +689,16 @@ async def test_get_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) response = await client.get_endpoint(request) @@ -782,13 +711,13 @@ async def test_get_endpoint_async( # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -797,15 +726,19 @@ async def test_get_endpoint_async_from_dict(): def test_get_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -817,20 +750,27 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -842,79 +782,99 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint(name="name_value",) + client.get_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", + endpoint_service.GetEndpointRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_endpoint(name="name_value",) + response = await client.get_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", + endpoint_service.GetEndpointRequest(), + name='name_value', ) -def test_list_endpoints( - transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest -): +def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -922,10 +882,13 @@ def test_list_endpoints( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_endpoints(request) @@ -940,7 +903,7 @@ def test_list_endpoints( assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_endpoints_from_dict(): @@ -951,24 +914,25 @@ def test_list_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: client.list_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.ListEndpointsRequest() - @pytest.mark.asyncio -async def test_list_endpoints_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest -): +async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.ListEndpointsRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -976,13 +940,13 @@ async def test_list_endpoints_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_endpoints(request) @@ -995,7 +959,7 @@ async def test_list_endpoints_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1004,15 +968,19 @@ async def test_list_endpoints_async_from_dict(): def test_list_endpoints_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -1024,23 +992,28 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) await client.list_endpoints(request) @@ -1051,81 +1024,104 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_endpoints_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints(parent="parent_value",) + client.list_endpoints( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_endpoints_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", + endpoint_service.ListEndpointsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints(parent="parent_value",) + response = await client.list_endpoints( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", + endpoint_service.ListEndpointsRequest(), + parent='parent_value', ) def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1134,23 +1130,32 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_endpoints(request={}) @@ -1158,14 +1163,18 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in results) - + assert all(isinstance(i, endpoint.Endpoint) + for i in results) def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1174,32 +1183,40 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1208,37 +1225,46 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in responses) - + assert all(isinstance(i, endpoint.Endpoint) + for i in responses) @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1247,31 +1273,37 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_update_endpoint( - transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest -): +def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1279,13 +1311,19 @@ def test_update_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + ) response = client.update_endpoint(request) @@ -1300,13 +1338,13 @@ def test_update_endpoint( assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_endpoint_from_dict(): @@ -1317,24 +1355,25 @@ def test_update_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: client.update_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UpdateEndpointRequest() - @pytest.mark.asyncio -async def test_update_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest -): +async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UpdateEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1342,16 +1381,16 @@ async def test_update_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) response = await client.update_endpoint(request) @@ -1364,13 +1403,13 @@ async def test_update_endpoint_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -1379,15 +1418,19 @@ async def test_update_endpoint_async_from_dict(): def test_update_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" + request.endpoint.name = 'endpoint.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1399,25 +1442,28 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" + request.endpoint.name = 'endpoint.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) await client.update_endpoint(request) @@ -1428,24 +1474,29 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] def test_update_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1453,41 +1504,45 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1495,30 +1550,31 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_delete_endpoint( - transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest -): +def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1526,9 +1582,11 @@ def test_delete_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_endpoint(request) @@ -1550,24 +1608,25 @@ def test_delete_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: client.delete_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeleteEndpointRequest() - @pytest.mark.asyncio -async def test_delete_endpoint_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest -): +async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeleteEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1575,10 +1634,12 @@ async def test_delete_endpoint_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_endpoint(request) @@ -1599,16 +1660,20 @@ async def test_delete_endpoint_async_from_dict(): def test_delete_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_endpoint(request) @@ -1619,23 +1684,28 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_endpoint(request) @@ -1646,81 +1716,101 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_endpoint(name="name_value",) + client.delete_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", + endpoint_service.DeleteEndpointRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint(name="name_value",) + response = await client.delete_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", + endpoint_service.DeleteEndpointRequest(), + name='name_value', ) -def test_deploy_model( - transport: str = "grpc", request_type=endpoint_service.DeployModelRequest -): +def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1728,9 +1818,11 @@ def test_deploy_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.deploy_model(request) @@ -1752,24 +1844,25 @@ def test_deploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: client.deploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeployModelRequest() - @pytest.mark.asyncio -async def test_deploy_model_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest -): +async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeployModelRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1777,10 +1870,12 @@ async def test_deploy_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.deploy_model(request) @@ -1801,16 +1896,20 @@ async def test_deploy_model_async_from_dict(): def test_deploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.deploy_model(request) @@ -1821,23 +1920,28 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.deploy_model(request) @@ -1848,29 +1952,30 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] def test_deploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1878,63 +1983,51 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) + assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1942,45 +2035,34 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) + assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) -def test_undeploy_model( - transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest -): +def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1988,9 +2070,11 @@ def test_undeploy_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.undeploy_model(request) @@ -2012,24 +2096,25 @@ def test_undeploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: client.undeploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UndeployModelRequest() - @pytest.mark.asyncio -async def test_undeploy_model_async( - transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest -): +async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UndeployModelRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2037,10 +2122,12 @@ async def test_undeploy_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.undeploy_model(request) @@ -2061,16 +2148,20 @@ async def test_undeploy_model_async_from_dict(): def test_undeploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.undeploy_model(request) @@ -2081,23 +2172,28 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.undeploy_model(request) @@ -2108,23 +2204,30 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] def test_undeploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -2132,45 +2235,51 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].deployed_model_id == 'deployed_model_id_value' - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -2178,25 +2287,27 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].deployed_model_id == 'deployed_model_id_value' - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) @@ -2207,7 +2318,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2226,7 +2338,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -2254,16 +2367,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2271,8 +2381,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.EndpointServiceGrpcTransport, + ) def test_endpoint_service_base_transport_error(): @@ -2280,15 +2395,13 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2297,14 +2410,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_endpoint", - "get_endpoint", - "list_endpoints", - "update_endpoint", - "delete_endpoint", - "deploy_model", - "undeploy_model", - ) + 'create_endpoint', + 'get_endpoint', + 'list_endpoints', + 'update_endpoint', + 'delete_endpoint', + 'deploy_model', + 'undeploy_model', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2317,28 +2430,23 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2347,11 +2455,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -2359,25 +2467,19 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -2386,13 +2488,15 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_c transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2407,40 +2511,38 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_c with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_endpoint_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2448,11 +2550,12 @@ def test_endpoint_service_grpc_transport_channel(): def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2461,22 +2564,12 @@ def test_endpoint_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2485,7 +2578,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2501,7 +2594,9 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2515,23 +2610,17 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +def test_endpoint_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2548,7 +2637,9 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2561,12 +2652,16 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2574,12 +2669,16 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2590,18 +2689,17 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = EndpointServiceClient.endpoint_path(**expected) @@ -2609,24 +2707,22 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = EndpointServiceClient.model_path(**expected) @@ -2634,20 +2730,18 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2655,18 +2749,18 @@ def test_parse_common_billing_account_path(): actual = EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = EndpointServiceClient.common_folder_path(**expected) @@ -2674,18 +2768,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = EndpointServiceClient.common_organization_path(**expected) @@ -2693,18 +2787,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = EndpointServiceClient.common_project_path(**expected) @@ -2712,22 +2806,20 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = EndpointServiceClient.common_location_path(**expected) @@ -2739,19 +2831,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.EndpointServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.EndpointServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index e230d9f4b8..7593ba87a6 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -41,28 +41,25 @@ from google.cloud.aiplatform_v1beta1.services.job_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import explanation_metadata from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study from google.longrunning import operations_pb2 @@ -84,11 +81,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -99,49 +92,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert ( - JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) + assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [JobServiceClient, JobServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + JobServiceClient, + JobServiceAsyncClient, +]) def test_job_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [JobServiceClient, JobServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + JobServiceClient, + JobServiceAsyncClient, +]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -151,7 +131,7 @@ def test_job_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_job_service_client_get_transport_class(): @@ -165,42 +145,29 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) -def test_job_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +def test_job_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -216,7 +183,7 @@ def test_job_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -232,7 +199,7 @@ def test_job_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -252,15 +219,13 @@ def test_job_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -273,50 +238,26 @@ def test_job_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -339,18 +280,10 @@ def test_job_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -371,14 +304,9 @@ def test_job_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -392,23 +320,16 @@ def test_job_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -421,24 +342,16 @@ def test_job_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -453,11 +366,11 @@ def test_job_service_client_client_options_credentials_file( def test_job_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + client = JobServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -469,11 +382,10 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job( - transport: str = "grpc", request_type=job_service.CreateCustomJobRequest -): +def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -482,13 +394,16 @@ def test_create_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_custom_job(request) @@ -503,9 +418,9 @@ def test_create_custom_job( assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -518,26 +433,25 @@ def test_create_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: client.create_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateCustomJobRequest() - @pytest.mark.asyncio -async def test_create_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest -): +async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -546,16 +460,14 @@ async def test_create_custom_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_custom_job(request) @@ -568,9 +480,9 @@ async def test_create_custom_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -581,17 +493,19 @@ async def test_create_custom_job_async_from_dict(): def test_create_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -603,25 +517,28 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) + type(client.transport.create_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) await client.create_custom_job(request) @@ -632,24 +549,29 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -657,43 +579,45 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') def test_create_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -701,30 +625,31 @@ async def test_create_custom_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') @pytest.mark.asyncio async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_custom_job( job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) -def test_get_custom_job( - transport: str = "grpc", request_type=job_service.GetCustomJobRequest -): +def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -732,12 +657,17 @@ def test_get_custom_job( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_custom_job(request) @@ -752,9 +682,9 @@ def test_get_custom_job( assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -767,24 +697,25 @@ def test_get_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: client.get_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetCustomJobRequest() - @pytest.mark.asyncio -async def test_get_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest -): +async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -792,15 +723,15 @@ async def test_get_custom_job_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_custom_job(request) @@ -813,9 +744,9 @@ async def test_get_custom_job_async( # Establish that the response is the type that we expect. assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -826,15 +757,19 @@ async def test_get_custom_job_async_from_dict(): def test_get_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: call.return_value = custom_job.CustomJob() client.get_custom_job(request) @@ -846,23 +781,28 @@ def test_get_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) await client.get_custom_job(request) @@ -873,81 +813,99 @@ async def test_get_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_custom_job(name="name_value",) + client.get_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", + job_service.GetCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_custom_job(name="name_value",) + response = await client.get_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", + job_service.GetCustomJobRequest(), + name='name_value', ) -def test_list_custom_jobs( - transport: str = "grpc", request_type=job_service.ListCustomJobsRequest -): +def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -955,10 +913,13 @@ def test_list_custom_jobs( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_custom_jobs(request) @@ -973,7 +934,7 @@ def test_list_custom_jobs( assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_custom_jobs_from_dict(): @@ -984,24 +945,25 @@ def test_list_custom_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: client.list_custom_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListCustomJobsRequest() - @pytest.mark.asyncio -async def test_list_custom_jobs_async( - transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest -): +async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListCustomJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1009,11 +971,13 @@ async def test_list_custom_jobs_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_custom_jobs(request) @@ -1026,7 +990,7 @@ async def test_list_custom_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListCustomJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1035,15 +999,19 @@ async def test_list_custom_jobs_async_from_dict(): def test_list_custom_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: call.return_value = job_service.ListCustomJobsResponse() client.list_custom_jobs(request) @@ -1055,23 +1023,28 @@ def test_list_custom_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) await client.list_custom_jobs(request) @@ -1082,81 +1055,104 @@ async def test_list_custom_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_custom_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_custom_jobs(parent="parent_value",) + client.list_custom_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_custom_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", + job_service.ListCustomJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_custom_jobs(parent="parent_value",) + response = await client.list_custom_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", + job_service.ListCustomJobsRequest(), + parent='parent_value', ) def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1165,21 +1161,32 @@ def test_list_custom_jobs_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_custom_jobs(request={}) @@ -1187,14 +1194,18 @@ def test_list_custom_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in results) - + assert all(isinstance(i, custom_job.CustomJob) + for i in results) def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1203,30 +1214,40 @@ def test_list_custom_jobs_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_custom_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1235,35 +1256,46 @@ async def test_list_custom_jobs_async_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in responses) - + assert all(isinstance(i, custom_job.CustomJob) + for i in responses) @pytest.mark.asyncio async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_custom_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1272,29 +1304,37 @@ async def test_list_custom_jobs_async_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_custom_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_custom_job( - transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest -): +def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1303,10 +1343,10 @@ def test_delete_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_custom_job(request) @@ -1328,26 +1368,25 @@ def test_delete_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: client.delete_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteCustomJobRequest() - @pytest.mark.asyncio -async def test_delete_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest -): +async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1356,11 +1395,11 @@ async def test_delete_custom_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_custom_job(request) @@ -1381,18 +1420,20 @@ async def test_delete_custom_job_async_from_dict(): def test_delete_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_custom_job(request) @@ -1403,25 +1444,28 @@ def test_delete_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_custom_job(request) @@ -1432,85 +1476,101 @@ async def test_delete_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_custom_job(name="name_value",) + client.delete_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", + job_service.DeleteCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_custom_job(name="name_value",) + response = await client.delete_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", + job_service.DeleteCustomJobRequest(), + name='name_value', ) -def test_cancel_custom_job( - transport: str = "grpc", request_type=job_service.CancelCustomJobRequest -): +def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1519,8 +1579,8 @@ def test_cancel_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1544,26 +1604,25 @@ def test_cancel_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: client.cancel_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelCustomJobRequest() - @pytest.mark.asyncio -async def test_cancel_custom_job_async( - transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest -): +async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1572,8 +1631,8 @@ async def test_cancel_custom_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1595,17 +1654,19 @@ async def test_cancel_custom_job_async_from_dict(): def test_cancel_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: call.return_value = None client.cancel_custom_job(request) @@ -1617,22 +1678,27 @@ def test_cancel_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_custom_job(request) @@ -1644,83 +1710,99 @@ async def test_cancel_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_custom_job(name="name_value",) + client.cancel_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", + job_service.CancelCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_custom_job(name="name_value",) + response = await client.cancel_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", + job_service.CancelCustomJobRequest(), + name='name_value', ) -def test_create_data_labeling_job( - transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest -): +def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1729,19 +1811,28 @@ def test_create_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, - specialist_pools=["specialist_pools_value"], + + specialist_pools=['specialist_pools_value'], + ) response = client.create_data_labeling_job(request) @@ -1756,23 +1847,23 @@ def test_create_data_labeling_job( assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] def test_create_data_labeling_job_from_dict(): @@ -1783,27 +1874,25 @@ def test_create_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: client.create_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_create_data_labeling_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CreateDataLabelingJobRequest, -): +async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1812,22 +1901,20 @@ async def test_create_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) response = await client.create_data_labeling_job(request) @@ -1840,23 +1927,23 @@ async def test_create_data_labeling_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] @pytest.mark.asyncio @@ -1865,17 +1952,19 @@ async def test_create_data_labeling_job_async_from_dict(): def test_create_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: call.return_value = gca_data_labeling_job.DataLabelingJob() client.create_data_labeling_job(request) @@ -1887,25 +1976,28 @@ def test_create_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) + type(client.transport.create_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) await client.create_data_labeling_job(request) @@ -1916,24 +2008,29 @@ async def test_create_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1941,45 +2038,45 @@ def test_create_data_labeling_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1987,32 +2084,31 @@ async def test_create_data_labeling_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) -def test_get_data_labeling_job( - transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest -): +def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2021,19 +2117,28 @@ def test_get_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, - specialist_pools=["specialist_pools_value"], + + specialist_pools=['specialist_pools_value'], + ) response = client.get_data_labeling_job(request) @@ -2048,23 +2153,23 @@ def test_get_data_labeling_job( assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] def test_get_data_labeling_job_from_dict(): @@ -2075,26 +2180,25 @@ def test_get_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: client.get_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_get_data_labeling_job_async( - transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest -): +async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2103,22 +2207,20 @@ async def test_get_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) response = await client.get_data_labeling_job(request) @@ -2131,23 +2233,23 @@ async def test_get_data_labeling_job_async( # Establish that the response is the type that we expect. assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] @pytest.mark.asyncio @@ -2156,17 +2258,19 @@ async def test_get_data_labeling_job_async_from_dict(): def test_get_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -2178,25 +2282,28 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) + type(client.transport.get_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) await client.get_data_labeling_job(request) @@ -2207,85 +2314,99 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job(name="name_value",) + client.get_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", + job_service.GetDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job(name="name_value",) + response = await client.get_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", + job_service.GetDataLabelingJobRequest(), + name='name_value', ) -def test_list_data_labeling_jobs( - transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest -): +def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2294,11 +2415,12 @@ def test_list_data_labeling_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_data_labeling_jobs(request) @@ -2313,7 +2435,7 @@ def test_list_data_labeling_jobs( assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_data_labeling_jobs_from_dict(): @@ -2324,27 +2446,25 @@ def test_list_data_labeling_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: client.list_data_labeling_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListDataLabelingJobsRequest() - @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async( - transport: str = "grpc_asyncio", - request_type=job_service.ListDataLabelingJobsRequest, -): +async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListDataLabelingJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2353,14 +2473,12 @@ async def test_list_data_labeling_jobs_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_data_labeling_jobs(request) @@ -2373,7 +2491,7 @@ async def test_list_data_labeling_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2382,17 +2500,19 @@ async def test_list_data_labeling_jobs_async_from_dict(): def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2404,25 +2524,28 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse() - ) + type(client.transport.list_data_labeling_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) await client.list_data_labeling_jobs(request) @@ -2433,87 +2556,104 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_labeling_jobs(parent="parent_value",) + client.list_data_labeling_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", + job_service.ListDataLabelingJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs(parent="parent_value",) + response = await client.list_data_labeling_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", + job_service.ListDataLabelingJobsRequest(), + parent='parent_value', ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2522,14 +2662,17 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2542,7 +2685,9 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_data_labeling_jobs(request={}) @@ -2550,16 +2695,18 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) - + assert all(isinstance(i, data_labeling_job.DataLabelingJob) + for i in results) def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2568,14 +2715,17 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2586,20 +2736,19 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2608,14 +2757,17 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2626,25 +2778,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) - + assert all(isinstance(i, data_labeling_job.DataLabelingJob) + for i in responses) @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2653,14 +2805,17 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2673,15 +2828,14 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job( - transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest -): +def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2690,10 +2844,10 @@ def test_delete_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_data_labeling_job(request) @@ -2715,27 +2869,25 @@ def test_delete_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: client.delete_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_delete_data_labeling_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.DeleteDataLabelingJobRequest, -): +async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2744,11 +2896,11 @@ async def test_delete_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_data_labeling_job(request) @@ -2769,18 +2921,20 @@ async def test_delete_data_labeling_job_async_from_dict(): def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_data_labeling_job(request) @@ -2791,25 +2945,28 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_data_labeling_job(request) @@ -2820,85 +2977,101 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_data_labeling_job(name="name_value",) + client.delete_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", + job_service.DeleteDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job(name="name_value",) + response = await client.delete_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", + job_service.DeleteDataLabelingJobRequest(), + name='name_value', ) -def test_cancel_data_labeling_job( - transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest -): +def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2907,8 +3080,8 @@ def test_cancel_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2932,27 +3105,25 @@ def test_cancel_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: client.cancel_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelDataLabelingJobRequest() - @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CancelDataLabelingJobRequest, -): +async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2961,8 +3132,8 @@ async def test_cancel_data_labeling_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -2984,17 +3155,19 @@ async def test_cancel_data_labeling_job_async_from_dict(): def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -3006,22 +3179,27 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -3033,84 +3211,99 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_data_labeling_job(name="name_value",) + client.cancel_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", + job_service.CancelDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job(name="name_value",) + response = await client.cancel_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", + job_service.CancelDataLabelingJobRequest(), + name='name_value', ) -def test_create_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3119,16 +3312,22 @@ def test_create_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_hyperparameter_tuning_job(request) @@ -3143,9 +3342,9 @@ def test_create_hyperparameter_tuning_job( assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3164,27 +3363,25 @@ def test_create_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: client.create_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3193,19 +3390,17 @@ async def test_create_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_hyperparameter_tuning_job(request) @@ -3218,9 +3413,9 @@ async def test_create_hyperparameter_tuning_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3237,17 +3432,19 @@ async def test_create_hyperparameter_tuning_job_async_from_dict(): def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -3259,25 +3456,28 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob() - ) + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) await client.create_hyperparameter_tuning_job(request) @@ -3288,26 +3488,29 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -3315,51 +3518,45 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -3367,36 +3564,31 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) -def test_get_hyperparameter_tuning_job( - transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest -): +def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3405,16 +3597,22 @@ def test_get_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_hyperparameter_tuning_job(request) @@ -3429,9 +3627,9 @@ def test_get_hyperparameter_tuning_job( assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3450,27 +3648,25 @@ def test_get_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: client.get_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.GetHyperparameterTuningJobRequest, -): +async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3479,19 +3675,17 @@ async def test_get_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_hyperparameter_tuning_job(request) @@ -3504,9 +3698,9 @@ async def test_get_hyperparameter_tuning_job_async( # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3523,17 +3717,19 @@ async def test_get_hyperparameter_tuning_job_async_from_dict(): def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3545,25 +3741,28 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob() - ) + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) await client.get_hyperparameter_tuning_job(request) @@ -3574,86 +3773,99 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job(name="name_value",) + client.get_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", + job_service.GetHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job(name="name_value",) + response = await client.get_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", + job_service.GetHyperparameterTuningJobRequest(), + name='name_value', ) -def test_list_hyperparameter_tuning_jobs( - transport: str = "grpc", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3662,11 +3874,12 @@ def test_list_hyperparameter_tuning_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3681,7 +3894,7 @@ def test_list_hyperparameter_tuning_jobs( assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3692,27 +3905,25 @@ def test_list_hyperparameter_tuning_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: client.list_hyperparameter_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListHyperparameterTuningJobsRequest() - @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async( - transport: str = "grpc_asyncio", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListHyperparameterTuningJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3721,14 +3932,12 @@ async def test_list_hyperparameter_tuning_jobs_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3741,7 +3950,7 @@ async def test_list_hyperparameter_tuning_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -3750,17 +3959,19 @@ async def test_list_hyperparameter_tuning_jobs_async_from_dict(): def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3772,25 +3983,28 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse() - ) + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) await client.list_hyperparameter_tuning_jobs(request) @@ -3801,87 +4015,104 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_hyperparameter_tuning_jobs(parent="parent_value",) + client.list_hyperparameter_tuning_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", + job_service.ListHyperparameterTuningJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) + response = await client.list_hyperparameter_tuning_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", + job_service.ListHyperparameterTuningJobsRequest(), + parent='parent_value', ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3890,16 +4121,17 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3912,7 +4144,9 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -3920,19 +4154,18 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results - ) - + assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results) def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3941,16 +4174,17 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3961,20 +4195,19 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3983,16 +4216,17 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4003,28 +4237,25 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses - ) - + assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4033,16 +4264,17 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4053,20 +4285,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( - await client.list_hyperparameter_tuning_jobs(request={}) - ).pages: + async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4075,10 +4303,10 @@ def test_delete_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_hyperparameter_tuning_job(request) @@ -4100,27 +4328,25 @@ def test_delete_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: client.delete_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4129,11 +4355,11 @@ async def test_delete_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_hyperparameter_tuning_job(request) @@ -4154,18 +4380,20 @@ async def test_delete_hyperparameter_tuning_job_async_from_dict(): def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_hyperparameter_tuning_job(request) @@ -4176,25 +4404,28 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_hyperparameter_tuning_job(request) @@ -4205,86 +4436,101 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job(name="name_value",) + client.delete_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + job_service.DeleteHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job(name="name_value",) + response = await client.delete_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + job_service.DeleteHyperparameterTuningJobRequest(), + name='name_value', ) -def test_cancel_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4293,8 +4539,8 @@ def test_cancel_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -4318,27 +4564,25 @@ def test_cancel_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: client.cancel_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelHyperparameterTuningJobRequest() - @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4347,8 +4591,8 @@ async def test_cancel_hyperparameter_tuning_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -4370,17 +4614,19 @@ async def test_cancel_hyperparameter_tuning_job_async_from_dict(): def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -4392,22 +4638,27 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4419,83 +4670,99 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job(name="name_value",) + client.cancel_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + job_service.CancelHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job(name="name_value",) + response = await client.cancel_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + job_service.CancelHyperparameterTuningJobRequest(), + name='name_value', ) -def test_create_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest -): +def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4504,15 +4771,20 @@ def test_create_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", + name='name_value', + + display_name='display_name_value', + + model='model_value', + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_batch_prediction_job(request) @@ -4527,11 +4799,11 @@ def test_create_batch_prediction_job( assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.generate_explanation is True @@ -4546,27 +4818,25 @@ def test_create_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: client.create_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateBatchPredictionJobRequest() - @pytest.mark.asyncio -async def test_create_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CreateBatchPredictionJobRequest, -): +async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateBatchPredictionJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4575,18 +4845,16 @@ async def test_create_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob( + name='name_value', + display_name='display_name_value', + model='model_value', + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_batch_prediction_job(request) @@ -4599,11 +4867,11 @@ async def test_create_batch_prediction_job_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.generate_explanation is True @@ -4616,17 +4884,19 @@ async def test_create_batch_prediction_job_async_from_dict(): def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: call.return_value = gca_batch_prediction_job.BatchPredictionJob() client.create_batch_prediction_job(request) @@ -4638,25 +4908,28 @@ def test_create_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob() - ) + type(client.transport.create_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) await client.create_batch_prediction_job(request) @@ -4667,26 +4940,29 @@ async def test_create_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -4694,51 +4970,45 @@ def test_create_batch_prediction_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) + assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), "__call__" - ) as call: + type(client.transport.create_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -4746,36 +5016,31 @@ async def test_create_batch_prediction_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) + assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), + parent='parent_value', + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), ) -def test_get_batch_prediction_job( - transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest -): +def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4784,15 +5049,20 @@ def test_get_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", + name='name_value', + + display_name='display_name_value', + + model='model_value', + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_batch_prediction_job(request) @@ -4807,11 +5077,11 @@ def test_get_batch_prediction_job( assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.generate_explanation is True @@ -4826,27 +5096,25 @@ def test_get_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: client.get_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetBatchPredictionJobRequest() - @pytest.mark.asyncio -async def test_get_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.GetBatchPredictionJobRequest, -): +async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetBatchPredictionJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4855,18 +5123,16 @@ async def test_get_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( + name='name_value', + display_name='display_name_value', + model='model_value', + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_batch_prediction_job(request) @@ -4879,11 +5145,11 @@ async def test_get_batch_prediction_job_async( # Establish that the response is the type that we expect. assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.model == "model_value" + assert response.model == 'model_value' assert response.generate_explanation is True @@ -4896,17 +5162,19 @@ async def test_get_batch_prediction_job_async_from_dict(): def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: call.return_value = batch_prediction_job.BatchPredictionJob() client.get_batch_prediction_job(request) @@ -4918,25 +5186,28 @@ def test_get_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob() - ) + type(client.transport.get_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) await client.get_batch_prediction_job(request) @@ -4947,85 +5218,99 @@ async def test_get_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_batch_prediction_job(name="name_value",) + client.get_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", + job_service.GetBatchPredictionJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), "__call__" - ) as call: + type(client.transport.get_batch_prediction_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_batch_prediction_job(name="name_value",) + response = await client.get_batch_prediction_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", + job_service.GetBatchPredictionJobRequest(), + name='name_value', ) -def test_list_batch_prediction_jobs( - transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest -): +def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5034,11 +5319,12 @@ def test_list_batch_prediction_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_batch_prediction_jobs(request) @@ -5053,7 +5339,7 @@ def test_list_batch_prediction_jobs( assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_batch_prediction_jobs_from_dict(): @@ -5064,27 +5350,25 @@ def test_list_batch_prediction_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: client.list_batch_prediction_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListBatchPredictionJobsRequest() - @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async( - transport: str = "grpc_asyncio", - request_type=job_service.ListBatchPredictionJobsRequest, -): +async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListBatchPredictionJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5093,14 +5377,12 @@ async def test_list_batch_prediction_jobs_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_batch_prediction_jobs(request) @@ -5113,7 +5395,7 @@ async def test_list_batch_prediction_jobs_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -5122,17 +5404,19 @@ async def test_list_batch_prediction_jobs_async_from_dict(): def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: call.return_value = job_service.ListBatchPredictionJobsResponse() client.list_batch_prediction_jobs(request) @@ -5144,25 +5428,28 @@ def test_list_batch_prediction_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse() - ) + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) await client.list_batch_prediction_jobs(request) @@ -5173,87 +5460,104 @@ async def test_list_batch_prediction_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_batch_prediction_jobs(parent="parent_value",) + client.list_batch_prediction_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", + job_service.ListBatchPredictionJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs(parent="parent_value",) + response = await client.list_batch_prediction_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", + job_service.ListBatchPredictionJobsRequest(), + parent='parent_value', ) def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5262,14 +5566,17 @@ def test_list_batch_prediction_jobs_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5282,7 +5589,9 @@ def test_list_batch_prediction_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_batch_prediction_jobs(request={}) @@ -5290,18 +5599,18 @@ def test_list_batch_prediction_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results - ) - + assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) + for i in results) def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), "__call__" - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5310,14 +5619,17 @@ def test_list_batch_prediction_jobs_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5328,20 +5640,19 @@ def test_list_batch_prediction_jobs_pages(): RuntimeError, ) pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_batch_prediction_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5350,14 +5661,17 @@ async def test_list_batch_prediction_jobs_async_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", + batch_prediction_jobs=[], + next_page_token='def', ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5368,64 +5682,2454 @@ async def test_list_batch_prediction_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) - assert len(responses) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses + assert len(responses) == 6 + assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) + for i in responses) + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async_pages(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.ListBatchPredictionJobsResponse( + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + batch_prediction_job.BatchPredictionJob(), + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='abc', + ), + job_service.ListBatchPredictionJobsResponse( + batch_prediction_jobs=[], + next_page_token='def', + ), + job_service.ListBatchPredictionJobsResponse( + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + ], + next_page_token='ghi', + ), + job_service.ListBatchPredictionJobsResponse( + batch_prediction_jobs=[ + batch_prediction_job.BatchPredictionJob(), + batch_prediction_job.BatchPredictionJob(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_batch_prediction_job_from_dict(): + test_delete_batch_prediction_job(request_type=dict) + + +def test_delete_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + client.delete_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteBatchPredictionJobRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async_from_dict(): + await test_delete_batch_prediction_job_async(request_type=dict) + + +def test_delete_batch_prediction_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_batch_prediction_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_batch_prediction_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_batch_prediction_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_batch_prediction_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), + name='name_value', + ) + + +def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_batch_prediction_job_from_dict(): + test_cancel_batch_prediction_job(request_type=dict) + + +def test_cancel_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + client.cancel_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelBatchPredictionJobRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async_from_dict(): + await test_cancel_batch_prediction_job_async(request_type=dict) + + +def test_cancel_batch_prediction_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + call.return_value = None + + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_cancel_batch_prediction_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_batch_prediction_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_cancel_batch_prediction_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_batch_prediction_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), + name='name_value', + ) + + +def test_create_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.CreateModelDeploymentMonitoringJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name='name_value', + + display_name='display_name_value', + + endpoint='endpoint_value', + + state=job_state.JobState.JOB_STATE_QUEUED, + + schedule_state=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, + + predict_instance_schema_uri='predict_instance_schema_uri_value', + + analysis_instance_schema_uri='analysis_instance_schema_uri_value', + + ) + + response = client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.endpoint == 'endpoint_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.schedule_state == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + + assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + + assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + + +def test_create_model_deployment_monitoring_job_from_dict(): + test_create_model_deployment_monitoring_job(request_type=dict) + + +def test_create_model_deployment_monitoring_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + client.create_model_deployment_monitoring_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() + +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateModelDeploymentMonitoringJobRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name='name_value', + display_name='display_name_value', + endpoint='endpoint_value', + state=job_state.JobState.JOB_STATE_QUEUED, + schedule_state=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, + predict_instance_schema_uri='predict_instance_schema_uri_value', + analysis_instance_schema_uri='analysis_instance_schema_uri_value', + )) + + response = await client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.endpoint == 'endpoint_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.schedule_state == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + + assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + + assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + + +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_async_from_dict(): + await test_create_model_deployment_monitoring_job_async(request_type=dict) + + +def test_create_model_deployment_monitoring_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateModelDeploymentMonitoringJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + + client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateModelDeploymentMonitoringJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + + await client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_model_deployment_monitoring_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_model_deployment_monitoring_job( + parent='parent_value', + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + + +def test_create_model_deployment_monitoring_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_model_deployment_monitoring_job( + job_service.CreateModelDeploymentMonitoringJobRequest(), + parent='parent_value', + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_model_deployment_monitoring_job( + parent='parent_value', + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + + +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_model_deployment_monitoring_job( + job_service.CreateModelDeploymentMonitoringJobRequest(), + parent='parent_value', + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + ) + + +def test_search_model_deployment_monitoring_stats_anomalies(transport: str = 'grpc', request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_search_model_deployment_monitoring_stats_anomalies_from_dict(): + test_search_model_deployment_monitoring_stats_anomalies(request_type=dict) + + +def test_search_model_deployment_monitoring_stats_anomalies_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + client.search_model_deployment_monitoring_stats_anomalies() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_async(transport: str = 'grpc_asyncio', request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_async_from_dict(): + await test_search_model_deployment_monitoring_stats_anomalies_async(request_type=dict) + + +def test_search_model_deployment_monitoring_stats_anomalies_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + request.model_deployment_monitoring_job = 'model_deployment_monitoring_job/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model_deployment_monitoring_job=model_deployment_monitoring_job/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + request.model_deployment_monitoring_job = 'model_deployment_monitoring_job/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse()) + + await client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model_deployment_monitoring_job=model_deployment_monitoring_job/value', + ) in kw['metadata'] + + +def test_search_model_deployment_monitoring_stats_anomalies_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.search_model_deployment_monitoring_stats_anomalies( + model_deployment_monitoring_job='model_deployment_monitoring_job_value', + deployed_model_id='deployed_model_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].model_deployment_monitoring_job == 'model_deployment_monitoring_job_value' + + assert args[0].deployed_model_id == 'deployed_model_id_value' + + +def test_search_model_deployment_monitoring_stats_anomalies_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.search_model_deployment_monitoring_stats_anomalies( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(), + model_deployment_monitoring_job='model_deployment_monitoring_job_value', + deployed_model_id='deployed_model_id_value', + ) + + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.search_model_deployment_monitoring_stats_anomalies( + model_deployment_monitoring_job='model_deployment_monitoring_job_value', + deployed_model_id='deployed_model_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].model_deployment_monitoring_job == 'model_deployment_monitoring_job_value' + + assert args[0].deployed_model_id == 'deployed_model_id_value' + + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.search_model_deployment_monitoring_stats_anomalies( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(), + model_deployment_monitoring_job='model_deployment_monitoring_job_value', + deployed_model_id='deployed_model_id_value', + ) + + +def test_search_model_deployment_monitoring_stats_anomalies_pager(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='abc', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[], + next_page_token='def', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='ghi', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model_deployment_monitoring_job', ''), + )), + ) + pager = client.search_model_deployment_monitoring_stats_anomalies(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies) + for i in results) + +def test_search_model_deployment_monitoring_stats_anomalies_pages(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='abc', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[], + next_page_token='def', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='ghi', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + ), + RuntimeError, + ) + pages = list(client.search_model_deployment_monitoring_stats_anomalies(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_async_pager(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='abc', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[], + next_page_token='def', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='ghi', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + ), + RuntimeError, + ) + async_pager = await client.search_model_deployment_monitoring_stats_anomalies(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies) + for i in responses) + +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_async_pages(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='abc', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[], + next_page_token='def', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + next_page_token='ghi', + ), + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + monitoring_stats=[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.search_model_deployment_monitoring_stats_anomalies(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_get_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.GetModelDeploymentMonitoringJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name='name_value', + + display_name='display_name_value', + + endpoint='endpoint_value', + + state=job_state.JobState.JOB_STATE_QUEUED, + + schedule_state=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, + + predict_instance_schema_uri='predict_instance_schema_uri_value', + + analysis_instance_schema_uri='analysis_instance_schema_uri_value', + + ) + + response = client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.endpoint == 'endpoint_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.schedule_state == model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + + assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + + assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + + +def test_get_model_deployment_monitoring_job_from_dict(): + test_get_model_deployment_monitoring_job(request_type=dict) + + +def test_get_model_deployment_monitoring_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + client.get_model_deployment_monitoring_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() + +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetModelDeploymentMonitoringJobRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name='name_value', + display_name='display_name_value', + endpoint='endpoint_value', + state=job_state.JobState.JOB_STATE_QUEUED, + schedule_state=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, + predict_instance_schema_uri='predict_instance_schema_uri_value', + analysis_instance_schema_uri='analysis_instance_schema_uri_value', + )) + + response = await client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.endpoint == 'endpoint_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.schedule_state == model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + + assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + + assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + + +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_async_from_dict(): + await test_get_model_deployment_monitoring_job_async(request_type=dict) + + +def test_get_model_deployment_monitoring_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetModelDeploymentMonitoringJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + + client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetModelDeploymentMonitoringJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + + await client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_model_deployment_monitoring_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model_deployment_monitoring_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_model_deployment_monitoring_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model_deployment_monitoring_job( + job_service.GetModelDeploymentMonitoringJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_model_deployment_monitoring_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_model_deployment_monitoring_job( + job_service.GetModelDeploymentMonitoringJobRequest(), + name='name_value', + ) + + +def test_list_model_deployment_monitoring_jobs(transport: str = 'grpc', request_type=job_service.ListModelDeploymentMonitoringJobsRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListModelDeploymentMonitoringJobsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListModelDeploymentMonitoringJobsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_model_deployment_monitoring_jobs_from_dict(): + test_list_model_deployment_monitoring_jobs(request_type=dict) + + +def test_list_model_deployment_monitoring_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + client.list_model_deployment_monitoring_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListModelDeploymentMonitoringJobsRequest() + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListModelDeploymentMonitoringJobsRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListModelDeploymentMonitoringJobsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListModelDeploymentMonitoringJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelDeploymentMonitoringJobsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_async_from_dict(): + await test_list_model_deployment_monitoring_jobs_async(request_type=dict) + + +def test_list_model_deployment_monitoring_jobs_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListModelDeploymentMonitoringJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse() + + client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListModelDeploymentMonitoringJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListModelDeploymentMonitoringJobsResponse()) + + await client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_model_deployment_monitoring_jobs_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_model_deployment_monitoring_jobs( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_model_deployment_monitoring_jobs_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_model_deployment_monitoring_jobs( + job_service.ListModelDeploymentMonitoringJobsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListModelDeploymentMonitoringJobsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_model_deployment_monitoring_jobs( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_model_deployment_monitoring_jobs( + job_service.ListModelDeploymentMonitoringJobsRequest(), + parent='parent_value', + ) + + +def test_list_model_deployment_monitoring_jobs_pager(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='abc', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[], + next_page_token='def', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='ghi', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_model_deployment_monitoring_jobs(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + for i in results) + +def test_list_model_deployment_monitoring_jobs_pages(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='abc', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[], + next_page_token='def', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='ghi', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + ), + RuntimeError, + ) + pages = list(client.list_model_deployment_monitoring_jobs(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_async_pager(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='abc', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[], + next_page_token='def', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='ghi', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_model_deployment_monitoring_jobs(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + for i in responses) + +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_async_pages(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_deployment_monitoring_jobs), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='abc', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[], + next_page_token='def', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + next_page_token='ghi', + ), + job_service.ListModelDeploymentMonitoringJobsResponse( + model_deployment_monitoring_jobs=[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_model_deployment_monitoring_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.UpdateModelDeploymentMonitoringJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_model_deployment_monitoring_job_from_dict(): + test_update_model_deployment_monitoring_job(request_type=dict) + + +def test_update_model_deployment_monitoring_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + client.update_model_deployment_monitoring_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() + +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.UpdateModelDeploymentMonitoringJobRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_async_from_dict(): + await test_update_model_deployment_monitoring_job_async(request_type=dict) + + +def test_update_model_deployment_monitoring_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.UpdateModelDeploymentMonitoringJobRequest() + request.model_deployment_monitoring_job.name = 'model_deployment_monitoring_job.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model_deployment_monitoring_job.name=model_deployment_monitoring_job.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.UpdateModelDeploymentMonitoringJobRequest() + request.model_deployment_monitoring_job.name = 'model_deployment_monitoring_job.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model_deployment_monitoring_job.name=model_deployment_monitoring_job.name/value', + ) in kw['metadata'] + + +def test_update_model_deployment_monitoring_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_model_deployment_monitoring_job( + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_model_deployment_monitoring_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_model_deployment_monitoring_job( + job_service.UpdateModelDeploymentMonitoringJobRequest(), + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_model_deployment_monitoring_job( + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_model_deployment_monitoring_job( + job_service.UpdateModelDeploymentMonitoringJobRequest(), + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.DeleteModelDeploymentMonitoringJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_model_deployment_monitoring_job_from_dict(): + test_delete_model_deployment_monitoring_job(request_type=dict) + + +def test_delete_model_deployment_monitoring_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + client.delete_model_deployment_monitoring_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteModelDeploymentMonitoringJobRequest() + +@pytest.mark.asyncio +async def test_delete_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteModelDeploymentMonitoringJobRequest): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteModelDeploymentMonitoringJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_model_deployment_monitoring_job_async_from_dict(): + await test_delete_model_deployment_monitoring_job_async(request_type=dict) + + +def test_delete_model_deployment_monitoring_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteModelDeploymentMonitoringJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_model_deployment_monitoring_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteModelDeploymentMonitoringJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_model_deployment_monitoring_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_model_deployment_monitoring_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_model_deployment_monitoring_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_model_deployment_monitoring_job( + job_service.DeleteModelDeploymentMonitoringJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_model_deployment_monitoring_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model_deployment_monitoring_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_model_deployment_monitoring_job( + name='name_value', ) + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) +async def test_delete_model_deployment_monitoring_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_model_deployment_monitoring_job( + job_service.DeleteModelDeploymentMonitoringJobRequest(), + name='name_value', ) - pages = [] - async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job( - transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest -): +def test_pause_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.PauseModelDeploymentMonitoringJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5434,52 +8138,50 @@ def test_delete_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = None - response = client.delete_batch_prediction_job(request) + response = client.pause_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == job_service.DeleteBatchPredictionJobRequest() + assert args[0] == job_service.PauseModelDeploymentMonitoringJobRequest() # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert response is None -def test_delete_batch_prediction_job_from_dict(): - test_delete_batch_prediction_job(request_type=dict) +def test_pause_model_deployment_monitoring_job_from_dict(): + test_pause_model_deployment_monitoring_job(request_type=dict) -def test_delete_batch_prediction_job_empty_call(): +def test_pause_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - client.delete_batch_prediction_job() + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: + client.pause_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == job_service.DeleteBatchPredictionJobRequest() - + assert args[0] == job_service.PauseModelDeploymentMonitoringJobRequest() @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.DeleteBatchPredictionJobRequest, -): +async def test_pause_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.PauseModelDeploymentMonitoringJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5488,45 +8190,45 @@ async def test_delete_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - response = await client.delete_batch_prediction_job(request) + response = await client.pause_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == job_service.DeleteBatchPredictionJobRequest() + assert args[0] == job_service.PauseModelDeploymentMonitoringJobRequest() # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert response is None @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async_from_dict(): - await test_delete_batch_prediction_job_async(request_type=dict) +async def test_pause_model_deployment_monitoring_job_async_from_dict(): + await test_pause_model_deployment_monitoring_job_async(request_type=dict) -def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) +def test_pause_model_deployment_monitoring_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = job_service.DeleteBatchPredictionJobRequest() - request.name = "name/value" + request = job_service.PauseModelDeploymentMonitoringJobRequest() + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = None - client.delete_batch_prediction_job(request) + client.pause_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -5535,27 +8237,30 @@ def test_delete_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio -async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) +async def test_pause_model_deployment_monitoring_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = job_service.DeleteBatchPredictionJobRequest() - request.name = "name/value" + request = job_service.PauseModelDeploymentMonitoringJobRequest() + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - await client.delete_batch_prediction_job(request) + await client.pause_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -5564,85 +8269,99 @@ async def test_delete_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] -def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) +def test_pause_model_deployment_monitoring_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_batch_prediction_job(name="name_value",) + client.pause_model_deployment_monitoring_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' -def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) +def test_pause_model_deployment_monitoring_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", + client.pause_model_deployment_monitoring_job( + job_service.PauseModelDeploymentMonitoringJobRequest(), + name='name_value', ) @pytest.mark.asyncio -async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) +async def test_pause_model_deployment_monitoring_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: + type(client.transport.pause_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = None - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job(name="name_value",) + response = await client.pause_model_deployment_monitoring_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio -async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) +async def test_pause_model_deployment_monitoring_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", + await client.pause_model_deployment_monitoring_job( + job_service.PauseModelDeploymentMonitoringJobRequest(), + name='name_value', ) -def test_cancel_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest -): +def test_resume_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.ResumeModelDeploymentMonitoringJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5651,52 +8370,50 @@ def test_cancel_batch_prediction_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None - response = client.cancel_batch_prediction_job(request) + response = client.resume_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == job_service.CancelBatchPredictionJobRequest() + assert args[0] == job_service.ResumeModelDeploymentMonitoringJobRequest() # Establish that the response is the type that we expect. assert response is None -def test_cancel_batch_prediction_job_from_dict(): - test_cancel_batch_prediction_job(request_type=dict) +def test_resume_model_deployment_monitoring_job_from_dict(): + test_resume_model_deployment_monitoring_job(request_type=dict) -def test_cancel_batch_prediction_job_empty_call(): +def test_resume_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: - client.cancel_batch_prediction_job() + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: + client.resume_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == job_service.CancelBatchPredictionJobRequest() - + assert args[0] == job_service.ResumeModelDeploymentMonitoringJobRequest() @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async( - transport: str = "grpc_asyncio", - request_type=job_service.CancelBatchPredictionJobRequest, -): +async def test_resume_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.ResumeModelDeploymentMonitoringJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5705,43 +8422,45 @@ async def test_cancel_batch_prediction_job_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - response = await client.cancel_batch_prediction_job(request) + response = await client.resume_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == job_service.CancelBatchPredictionJobRequest() + assert args[0] == job_service.ResumeModelDeploymentMonitoringJobRequest() # Establish that the response is the type that we expect. assert response is None @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async_from_dict(): - await test_cancel_batch_prediction_job_async(request_type=dict) +async def test_resume_model_deployment_monitoring_job_async_from_dict(): + await test_resume_model_deployment_monitoring_job_async(request_type=dict) -def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) +def test_resume_model_deployment_monitoring_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = job_service.CancelBatchPredictionJobRequest() - request.name = "name/value" + request = job_service.ResumeModelDeploymentMonitoringJobRequest() + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: call.return_value = None - client.cancel_batch_prediction_job(request) + client.resume_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -5750,25 +8469,30 @@ def test_cancel_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) +async def test_resume_model_deployment_monitoring_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = job_service.CancelBatchPredictionJobRequest() - request.name = "name/value" + request = job_service.ResumeModelDeploymentMonitoringJobRequest() + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - await client.cancel_batch_prediction_job(request) + await client.resume_model_deployment_monitoring_job(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -5777,75 +8501,92 @@ async def test_cancel_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] -def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) +def test_resume_model_deployment_monitoring_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_batch_prediction_job(name="name_value",) + client.resume_model_deployment_monitoring_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' -def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) +def test_resume_model_deployment_monitoring_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", + client.resume_model_deployment_monitoring_job( + job_service.ResumeModelDeploymentMonitoringJobRequest(), + name='name_value', ) @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) +async def test_resume_model_deployment_monitoring_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), "__call__" - ) as call: + type(client.transport.resume_model_deployment_monitoring_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job(name="name_value",) + response = await client.resume_model_deployment_monitoring_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) +async def test_resume_model_deployment_monitoring_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", + await client.resume_model_deployment_monitoring_job( + job_service.ResumeModelDeploymentMonitoringJobRequest(), + name='name_value', ) @@ -5856,7 +8597,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -5875,7 +8617,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -5903,13 +8646,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport,], -) +@pytest.mark.parametrize("transport_class", [ + transports.JobServiceGrpcTransport, + transports.JobServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -5917,8 +8660,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.JobServiceGrpcTransport,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.JobServiceGrpcTransport, + ) def test_job_service_base_transport_error(): @@ -5926,15 +8674,13 @@ def test_job_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_job_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -5943,27 +8689,35 @@ def test_job_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_custom_job", - "get_custom_job", - "list_custom_jobs", - "delete_custom_job", - "cancel_custom_job", - "create_data_labeling_job", - "get_data_labeling_job", - "list_data_labeling_jobs", - "delete_data_labeling_job", - "cancel_data_labeling_job", - "create_hyperparameter_tuning_job", - "get_hyperparameter_tuning_job", - "list_hyperparameter_tuning_jobs", - "delete_hyperparameter_tuning_job", - "cancel_hyperparameter_tuning_job", - "create_batch_prediction_job", - "get_batch_prediction_job", - "list_batch_prediction_jobs", - "delete_batch_prediction_job", - "cancel_batch_prediction_job", - ) + 'create_custom_job', + 'get_custom_job', + 'list_custom_jobs', + 'delete_custom_job', + 'cancel_custom_job', + 'create_data_labeling_job', + 'get_data_labeling_job', + 'list_data_labeling_jobs', + 'delete_data_labeling_job', + 'cancel_data_labeling_job', + 'create_hyperparameter_tuning_job', + 'get_hyperparameter_tuning_job', + 'list_hyperparameter_tuning_jobs', + 'delete_hyperparameter_tuning_job', + 'cancel_hyperparameter_tuning_job', + 'create_batch_prediction_job', + 'get_batch_prediction_job', + 'list_batch_prediction_jobs', + 'delete_batch_prediction_job', + 'cancel_batch_prediction_job', + 'create_model_deployment_monitoring_job', + 'search_model_deployment_monitoring_stats_anomalies', + 'get_model_deployment_monitoring_job', + 'list_model_deployment_monitoring_jobs', + 'update_model_deployment_monitoring_job', + 'delete_model_deployment_monitoring_job', + 'pause_model_deployment_monitoring_job', + 'resume_model_deployment_monitoring_job', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -5976,28 +8730,23 @@ def test_job_service_base_transport(): def test_job_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_job_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport() @@ -6006,11 +8755,11 @@ def test_job_service_base_transport_with_adc(): def test_job_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) JobServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -6018,22 +8767,19 @@ def test_job_service_auth_adc(): def test_job_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -6042,13 +8788,15 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class) transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -6063,40 +8811,38 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class) with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_job_service_host_with_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_job_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6104,11 +8850,12 @@ def test_job_service_grpc_transport_channel(): def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6117,17 +8864,12 @@ def test_job_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -6136,7 +8878,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -6152,7 +8894,9 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6166,20 +8910,17 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -6196,7 +8937,9 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6209,12 +8952,16 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): def test_job_service_grpc_lro_client(): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6222,12 +8969,16 @@ def test_job_service_grpc_lro_client(): def test_job_service_grpc_lro_async_client(): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6238,20 +8989,17 @@ def test_batch_prediction_job_path(): location = "clam" batch_prediction_job = "whelk" - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, location=location, batch_prediction_job=batch_prediction_job, - ) - actual = JobServiceClient.batch_prediction_job_path( - project, location, batch_prediction_job - ) + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) assert expected == actual def test_parse_batch_prediction_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", + } path = JobServiceClient.batch_prediction_job_path(**expected) @@ -6259,24 +9007,22 @@ def test_parse_batch_prediction_job_path(): actual = JobServiceClient.parse_batch_prediction_job_path(path) assert expected == actual - def test_custom_job_path(): project = "cuttlefish" location = "mussel" custom_job = "winkle" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) actual = JobServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", + } path = JobServiceClient.custom_job_path(**expected) @@ -6284,26 +9030,22 @@ def test_parse_custom_job_path(): actual = JobServiceClient.parse_custom_job_path(path) assert expected == actual - def test_data_labeling_job_path(): project = "squid" location = "clam" data_labeling_job = "whelk" - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) - actual = JobServiceClient.data_labeling_job_path( - project, location, data_labeling_job - ) + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) assert expected == actual def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", + } path = JobServiceClient.data_labeling_job_path(**expected) @@ -6311,24 +9053,22 @@ def test_parse_data_labeling_job_path(): actual = JobServiceClient.parse_data_labeling_job_path(path) assert expected == actual - def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = JobServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } path = JobServiceClient.dataset_path(**expected) @@ -6336,28 +9076,45 @@ def test_parse_dataset_path(): actual = JobServiceClient.parse_dataset_path(path) assert expected == actual - -def test_hyperparameter_tuning_job_path(): +def test_endpoint_path(): project = "squid" location = "clam" - hyperparameter_tuning_job = "whelk" + endpoint = "whelk" - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) - actual = JobServiceClient.hyperparameter_tuning_job_path( - project, location, hyperparameter_tuning_job - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + actual = JobServiceClient.endpoint_path(project, location, endpoint) + assert expected == actual + + +def test_parse_endpoint_path(): + expected = { + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + + } + path = JobServiceClient.endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_endpoint_path(path) + assert expected == actual + +def test_hyperparameter_tuning_job_path(): + project = "cuttlefish" + location = "mussel" + hyperparameter_tuning_job = "winkle" + + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) assert expected == actual def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "hyperparameter_tuning_job": "nudibranch", + "project": "nautilus", + "location": "scallop", + "hyperparameter_tuning_job": "abalone", + } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -6365,24 +9122,22 @@ def test_parse_hyperparameter_tuning_job_path(): actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) assert expected == actual - def test_model_path(): - project = "cuttlefish" - location = "mussel" - model = "winkle" + project = "squid" + location = "clam" + model = "whelk" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = JobServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "octopus", + "location": "oyster", + "model": "nudibranch", + } path = JobServiceClient.model_path(**expected) @@ -6390,6 +9145,28 @@ def test_parse_model_path(): actual = JobServiceClient.parse_model_path(path) assert expected == actual +def test_model_deployment_monitoring_job_path(): + project = "cuttlefish" + location = "mussel" + model_deployment_monitoring_job = "winkle" + + expected = "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format(project=project, location=location, model_deployment_monitoring_job=model_deployment_monitoring_job, ) + actual = JobServiceClient.model_deployment_monitoring_job_path(project, location, model_deployment_monitoring_job) + assert expected == actual + + +def test_parse_model_deployment_monitoring_job_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "model_deployment_monitoring_job": "abalone", + + } + path = JobServiceClient.model_deployment_monitoring_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_model_deployment_monitoring_job_path(path) + assert expected == actual def test_trial_path(): project = "squid" @@ -6397,19 +9174,18 @@ def test_trial_path(): study = "whelk" trial = "octopus" - expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( - project=project, location=location, study=study, trial=trial, - ) + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) actual = JobServiceClient.trial_path(project, location, study, trial) assert expected == actual def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", + } path = JobServiceClient.trial_path(**expected) @@ -6417,20 +9193,18 @@ def test_parse_trial_path(): actual = JobServiceClient.parse_trial_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = JobServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nautilus", + } path = JobServiceClient.common_billing_account_path(**expected) @@ -6438,18 +9212,18 @@ def test_parse_common_billing_account_path(): actual = JobServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = JobServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "abalone", + } path = JobServiceClient.common_folder_path(**expected) @@ -6457,18 +9231,18 @@ def test_parse_common_folder_path(): actual = JobServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = JobServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "clam", + } path = JobServiceClient.common_organization_path(**expected) @@ -6476,18 +9250,18 @@ def test_parse_common_organization_path(): actual = JobServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = JobServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "octopus", + } path = JobServiceClient.common_project_path(**expected) @@ -6495,22 +9269,20 @@ def test_parse_common_project_path(): actual = JobServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = JobServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "cuttlefish", + "location": "mussel", + } path = JobServiceClient.common_location_path(**expected) @@ -6522,19 +9294,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: transport_class = JobServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py new file mode 100644 index 0000000000..0a71403d33 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -0,0 +1,8524 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.metadata_service import MetadataServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.metadata_service import MetadataServiceClient +from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers +from google.cloud.aiplatform_v1beta1.services.metadata_service import transports +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import context as gca_context +from google.cloud.aiplatform_v1beta1.types import encryption_spec +from google.cloud.aiplatform_v1beta1.types import event +from google.cloud.aiplatform_v1beta1.types import execution +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import lineage_subgraph +from google.cloud.aiplatform_v1beta1.types import metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_schema as gca_metadata_schema +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import metadata_store +from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert MetadataServiceClient._get_default_mtls_endpoint(None) is None + assert MetadataServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert MetadataServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert MetadataServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert MetadataServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert MetadataServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ + MetadataServiceClient, + MetadataServiceAsyncClient, +]) +def test_metadata_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +@pytest.mark.parametrize("client_class", [ + MetadataServiceClient, + MetadataServiceAsyncClient, +]) +def test_metadata_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_metadata_service_client_get_transport_class(): + transport = MetadataServiceClient.get_transport_class() + available_transports = [ + transports.MetadataServiceGrpcTransport, + ] + assert transport in available_transports + + transport = MetadataServiceClient.get_transport_class("grpc") + assert transport == transports.MetadataServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(MetadataServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceClient)) +@mock.patch.object(MetadataServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceAsyncClient)) +def test_metadata_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(MetadataServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(MetadataServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc", "true"), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc", "false"), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(MetadataServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceClient)) +@mock.patch.object(MetadataServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_metadata_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_metadata_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_metadata_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_metadata_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = MetadataServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_metadata_store(transport: str = 'grpc', request_type=metadata_service.CreateMetadataStoreRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateMetadataStoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_metadata_store_from_dict(): + test_create_metadata_store(request_type=dict) + + +def test_create_metadata_store_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + client.create_metadata_store() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateMetadataStoreRequest() + +@pytest.mark.asyncio +async def test_create_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateMetadataStoreRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateMetadataStoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_metadata_store_async_from_dict(): + await test_create_metadata_store_async(request_type=dict) + + +def test_create_metadata_store_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateMetadataStoreRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_metadata_store_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateMetadataStoreRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_metadata_store_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_metadata_store( + parent='parent_value', + metadata_store=gca_metadata_store.MetadataStore(name='name_value'), + metadata_store_id='metadata_store_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].metadata_store == gca_metadata_store.MetadataStore(name='name_value') + + assert args[0].metadata_store_id == 'metadata_store_id_value' + + +def test_create_metadata_store_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_metadata_store( + metadata_service.CreateMetadataStoreRequest(), + parent='parent_value', + metadata_store=gca_metadata_store.MetadataStore(name='name_value'), + metadata_store_id='metadata_store_id_value', + ) + + +@pytest.mark.asyncio +async def test_create_metadata_store_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_metadata_store( + parent='parent_value', + metadata_store=gca_metadata_store.MetadataStore(name='name_value'), + metadata_store_id='metadata_store_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].metadata_store == gca_metadata_store.MetadataStore(name='name_value') + + assert args[0].metadata_store_id == 'metadata_store_id_value' + + +@pytest.mark.asyncio +async def test_create_metadata_store_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_metadata_store( + metadata_service.CreateMetadataStoreRequest(), + parent='parent_value', + metadata_store=gca_metadata_store.MetadataStore(name='name_value'), + metadata_store_id='metadata_store_id_value', + ) + + +def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_service.GetMetadataStoreRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_store.MetadataStore( + name='name_value', + + ) + + response = client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetMetadataStoreRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, metadata_store.MetadataStore) + + assert response.name == 'name_value' + + +def test_get_metadata_store_from_dict(): + test_get_metadata_store(request_type=dict) + + +def test_get_metadata_store_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + client.get_metadata_store() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetMetadataStoreRequest() + +@pytest.mark.asyncio +async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetMetadataStoreRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore( + name='name_value', + )) + + response = await client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetMetadataStoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_store.MetadataStore) + + assert response.name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_metadata_store_async_from_dict(): + await test_get_metadata_store_async(request_type=dict) + + +def test_get_metadata_store_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetMetadataStoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + call.return_value = metadata_store.MetadataStore() + + client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_metadata_store_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetMetadataStoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore()) + + await client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_metadata_store_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_store.MetadataStore() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_metadata_store( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_metadata_store_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_metadata_store( + metadata_service.GetMetadataStoreRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_metadata_store_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_store.MetadataStore() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_metadata_store( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_metadata_store_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_metadata_store( + metadata_service.GetMetadataStoreRequest(), + name='name_value', + ) + + +def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_service.ListMetadataStoresRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListMetadataStoresResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListMetadataStoresRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListMetadataStoresPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_metadata_stores_from_dict(): + test_list_metadata_stores(request_type=dict) + + +def test_list_metadata_stores_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + client.list_metadata_stores() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListMetadataStoresRequest() + +@pytest.mark.asyncio +async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListMetadataStoresRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListMetadataStoresRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListMetadataStoresAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_metadata_stores_async_from_dict(): + await test_list_metadata_stores_async(request_type=dict) + + +def test_list_metadata_stores_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListMetadataStoresRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + call.return_value = metadata_service.ListMetadataStoresResponse() + + client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_metadata_stores_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListMetadataStoresRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse()) + + await client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_metadata_stores_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListMetadataStoresResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_metadata_stores( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_metadata_stores_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_metadata_stores( + metadata_service.ListMetadataStoresRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_metadata_stores_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListMetadataStoresResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_metadata_stores( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_metadata_stores_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_metadata_stores( + metadata_service.ListMetadataStoresRequest(), + parent='parent_value', + ) + + +def test_list_metadata_stores_pager(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[], + next_page_token='def', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_metadata_stores(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, metadata_store.MetadataStore) + for i in results) + +def test_list_metadata_stores_pages(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[], + next_page_token='def', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + ), + RuntimeError, + ) + pages = list(client.list_metadata_stores(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_metadata_stores_async_pager(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[], + next_page_token='def', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_metadata_stores(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, metadata_store.MetadataStore) + for i in responses) + +@pytest.mark.asyncio +async def test_list_metadata_stores_async_pages(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_stores), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[], + next_page_token='def', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataStoresResponse( + metadata_stores=[ + metadata_store.MetadataStore(), + metadata_store.MetadataStore(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_metadata_stores(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_delete_metadata_store(transport: str = 'grpc', request_type=metadata_service.DeleteMetadataStoreRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.DeleteMetadataStoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_metadata_store_from_dict(): + test_delete_metadata_store(request_type=dict) + + +def test_delete_metadata_store_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + client.delete_metadata_store() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.DeleteMetadataStoreRequest() + +@pytest.mark.asyncio +async def test_delete_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.DeleteMetadataStoreRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.DeleteMetadataStoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_metadata_store_async_from_dict(): + await test_delete_metadata_store_async(request_type=dict) + + +def test_delete_metadata_store_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.DeleteMetadataStoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_metadata_store_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.DeleteMetadataStoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_metadata_store_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_metadata_store( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_metadata_store_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_metadata_store( + metadata_service.DeleteMetadataStoreRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_metadata_store_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_metadata_store), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_metadata_store( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_metadata_store_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_metadata_store( + metadata_service.DeleteMetadataStoreRequest(), + name='name_value', + ) + + +def test_create_artifact(transport: str = 'grpc', request_type=metadata_service.CreateArtifactRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_artifact.Artifact( + name='name_value', + + display_name='display_name_value', + + uri='uri_value', + + etag='etag_value', + + state=gca_artifact.Artifact.State.PENDING, + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateArtifactRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_artifact.Artifact) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.uri == 'uri_value' + + assert response.etag == 'etag_value' + + assert response.state == gca_artifact.Artifact.State.PENDING + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_create_artifact_from_dict(): + test_create_artifact(request_type=dict) + + +def test_create_artifact_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + client.create_artifact() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateArtifactRequest() + +@pytest.mark.asyncio +async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateArtifactRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact( + name='name_value', + display_name='display_name_value', + uri='uri_value', + etag='etag_value', + state=gca_artifact.Artifact.State.PENDING, + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateArtifactRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_artifact.Artifact) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.uri == 'uri_value' + + assert response.etag == 'etag_value' + + assert response.state == gca_artifact.Artifact.State.PENDING + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_create_artifact_async_from_dict(): + await test_create_artifact_async(request_type=dict) + + +def test_create_artifact_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateArtifactRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + call.return_value = gca_artifact.Artifact() + + client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_artifact_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateArtifactRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + + await client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_artifact_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_artifact.Artifact() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_artifact( + parent='parent_value', + artifact=gca_artifact.Artifact(name='name_value'), + artifact_id='artifact_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].artifact == gca_artifact.Artifact(name='name_value') + + assert args[0].artifact_id == 'artifact_id_value' + + +def test_create_artifact_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_artifact( + metadata_service.CreateArtifactRequest(), + parent='parent_value', + artifact=gca_artifact.Artifact(name='name_value'), + artifact_id='artifact_id_value', + ) + + +@pytest.mark.asyncio +async def test_create_artifact_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_artifact.Artifact() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_artifact( + parent='parent_value', + artifact=gca_artifact.Artifact(name='name_value'), + artifact_id='artifact_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].artifact == gca_artifact.Artifact(name='name_value') + + assert args[0].artifact_id == 'artifact_id_value' + + +@pytest.mark.asyncio +async def test_create_artifact_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_artifact( + metadata_service.CreateArtifactRequest(), + parent='parent_value', + artifact=gca_artifact.Artifact(name='name_value'), + artifact_id='artifact_id_value', + ) + + +def test_get_artifact(transport: str = 'grpc', request_type=metadata_service.GetArtifactRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = artifact.Artifact( + name='name_value', + + display_name='display_name_value', + + uri='uri_value', + + etag='etag_value', + + state=artifact.Artifact.State.PENDING, + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetArtifactRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, artifact.Artifact) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.uri == 'uri_value' + + assert response.etag == 'etag_value' + + assert response.state == artifact.Artifact.State.PENDING + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_get_artifact_from_dict(): + test_get_artifact(request_type=dict) + + +def test_get_artifact_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + client.get_artifact() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetArtifactRequest() + +@pytest.mark.asyncio +async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetArtifactRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact( + name='name_value', + display_name='display_name_value', + uri='uri_value', + etag='etag_value', + state=artifact.Artifact.State.PENDING, + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetArtifactRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, artifact.Artifact) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.uri == 'uri_value' + + assert response.etag == 'etag_value' + + assert response.state == artifact.Artifact.State.PENDING + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_get_artifact_async_from_dict(): + await test_get_artifact_async(request_type=dict) + + +def test_get_artifact_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetArtifactRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + call.return_value = artifact.Artifact() + + client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_artifact_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetArtifactRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact()) + + await client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_artifact_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = artifact.Artifact() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_artifact( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_artifact_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_artifact( + metadata_service.GetArtifactRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_artifact_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = artifact.Artifact() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_artifact( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_artifact_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_artifact( + metadata_service.GetArtifactRequest(), + name='name_value', + ) + + +def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.ListArtifactsRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListArtifactsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListArtifactsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListArtifactsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_artifacts_from_dict(): + test_list_artifacts(request_type=dict) + + +def test_list_artifacts_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + client.list_artifacts() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListArtifactsRequest() + +@pytest.mark.asyncio +async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListArtifactsRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListArtifactsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListArtifactsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_artifacts_async_from_dict(): + await test_list_artifacts_async(request_type=dict) + + +def test_list_artifacts_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListArtifactsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + call.return_value = metadata_service.ListArtifactsResponse() + + client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_artifacts_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListArtifactsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse()) + + await client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_artifacts_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListArtifactsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_artifacts( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_artifacts_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_artifacts( + metadata_service.ListArtifactsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_artifacts_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListArtifactsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_artifacts( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_artifacts_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_artifacts( + metadata_service.ListArtifactsRequest(), + parent='parent_value', + ) + + +def test_list_artifacts_pager(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + artifact.Artifact(), + ], + next_page_token='abc', + ), + metadata_service.ListArtifactsResponse( + artifacts=[], + next_page_token='def', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + ], + next_page_token='ghi', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_artifacts(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, artifact.Artifact) + for i in results) + +def test_list_artifacts_pages(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + artifact.Artifact(), + ], + next_page_token='abc', + ), + metadata_service.ListArtifactsResponse( + artifacts=[], + next_page_token='def', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + ], + next_page_token='ghi', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + ], + ), + RuntimeError, + ) + pages = list(client.list_artifacts(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_artifacts_async_pager(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + artifact.Artifact(), + ], + next_page_token='abc', + ), + metadata_service.ListArtifactsResponse( + artifacts=[], + next_page_token='def', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + ], + next_page_token='ghi', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_artifacts(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, artifact.Artifact) + for i in responses) + +@pytest.mark.asyncio +async def test_list_artifacts_async_pages(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_artifacts), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + artifact.Artifact(), + ], + next_page_token='abc', + ), + metadata_service.ListArtifactsResponse( + artifacts=[], + next_page_token='def', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + ], + next_page_token='ghi', + ), + metadata_service.ListArtifactsResponse( + artifacts=[ + artifact.Artifact(), + artifact.Artifact(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_artifacts(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_artifact(transport: str = 'grpc', request_type=metadata_service.UpdateArtifactRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_artifact.Artifact( + name='name_value', + + display_name='display_name_value', + + uri='uri_value', + + etag='etag_value', + + state=gca_artifact.Artifact.State.PENDING, + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateArtifactRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_artifact.Artifact) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.uri == 'uri_value' + + assert response.etag == 'etag_value' + + assert response.state == gca_artifact.Artifact.State.PENDING + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_update_artifact_from_dict(): + test_update_artifact(request_type=dict) + + +def test_update_artifact_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + client.update_artifact() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateArtifactRequest() + +@pytest.mark.asyncio +async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateArtifactRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact( + name='name_value', + display_name='display_name_value', + uri='uri_value', + etag='etag_value', + state=gca_artifact.Artifact.State.PENDING, + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateArtifactRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_artifact.Artifact) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.uri == 'uri_value' + + assert response.etag == 'etag_value' + + assert response.state == gca_artifact.Artifact.State.PENDING + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_update_artifact_async_from_dict(): + await test_update_artifact_async(request_type=dict) + + +def test_update_artifact_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateArtifactRequest() + request.artifact.name = 'artifact.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + call.return_value = gca_artifact.Artifact() + + client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'artifact.name=artifact.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_artifact_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateArtifactRequest() + request.artifact.name = 'artifact.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + + await client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'artifact.name=artifact.name/value', + ) in kw['metadata'] + + +def test_update_artifact_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_artifact.Artifact() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_artifact( + artifact=gca_artifact.Artifact(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].artifact == gca_artifact.Artifact(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_artifact_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_artifact( + metadata_service.UpdateArtifactRequest(), + artifact=gca_artifact.Artifact(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_artifact_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_artifact), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_artifact.Artifact() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_artifact( + artifact=gca_artifact.Artifact(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].artifact == gca_artifact.Artifact(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_artifact_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_artifact( + metadata_service.UpdateArtifactRequest(), + artifact=gca_artifact.Artifact(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_create_context(transport: str = 'grpc', request_type=metadata_service.CreateContextRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context( + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + + parent_contexts=['parent_contexts_value'], + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.create_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateContextRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_context.Context) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.parent_contexts == ['parent_contexts_value'] + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_create_context_from_dict(): + test_create_context(request_type=dict) + + +def test_create_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + client.create_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateContextRequest() + +@pytest.mark.asyncio +async def test_create_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateContextRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context( + name='name_value', + display_name='display_name_value', + etag='etag_value', + parent_contexts=['parent_contexts_value'], + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.create_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_context.Context) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.parent_contexts == ['parent_contexts_value'] + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_create_context_async_from_dict(): + await test_create_context_async(request_type=dict) + + +def test_create_context_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateContextRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + call.return_value = gca_context.Context() + + client.create_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_context_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateContextRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + + await client.create_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_context_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_context( + parent='parent_value', + context=gca_context.Context(name='name_value'), + context_id='context_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].context == gca_context.Context(name='name_value') + + assert args[0].context_id == 'context_id_value' + + +def test_create_context_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_context( + metadata_service.CreateContextRequest(), + parent='parent_value', + context=gca_context.Context(name='name_value'), + context_id='context_id_value', + ) + + +@pytest.mark.asyncio +async def test_create_context_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_context( + parent='parent_value', + context=gca_context.Context(name='name_value'), + context_id='context_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].context == gca_context.Context(name='name_value') + + assert args[0].context_id == 'context_id_value' + + +@pytest.mark.asyncio +async def test_create_context_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_context( + metadata_service.CreateContextRequest(), + parent='parent_value', + context=gca_context.Context(name='name_value'), + context_id='context_id_value', + ) + + +def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetContextRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = context.Context( + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + + parent_contexts=['parent_contexts_value'], + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.get_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetContextRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, context.Context) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.parent_contexts == ['parent_contexts_value'] + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_get_context_from_dict(): + test_get_context(request_type=dict) + + +def test_get_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + client.get_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetContextRequest() + +@pytest.mark.asyncio +async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetContextRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context( + name='name_value', + display_name='display_name_value', + etag='etag_value', + parent_contexts=['parent_contexts_value'], + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.get_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, context.Context) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.parent_contexts == ['parent_contexts_value'] + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_get_context_async_from_dict(): + await test_get_context_async(request_type=dict) + + +def test_get_context_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetContextRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + call.return_value = context.Context() + + client.get_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_context_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetContextRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) + + await client.get_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_context_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = context.Context() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_context( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_context_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_context( + metadata_service.GetContextRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_context_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = context.Context() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_context( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_context_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_context( + metadata_service.GetContextRequest(), + name='name_value', + ) + + +def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.ListContextsRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListContextsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListContextsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListContextsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_contexts_from_dict(): + test_list_contexts(request_type=dict) + + +def test_list_contexts_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + client.list_contexts() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListContextsRequest() + +@pytest.mark.asyncio +async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListContextsRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListContextsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListContextsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_contexts_async_from_dict(): + await test_list_contexts_async(request_type=dict) + + +def test_list_contexts_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListContextsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + call.return_value = metadata_service.ListContextsResponse() + + client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_contexts_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListContextsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse()) + + await client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_contexts_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListContextsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_contexts( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_contexts_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_contexts( + metadata_service.ListContextsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_contexts_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListContextsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_contexts( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_contexts_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_contexts( + metadata_service.ListContextsRequest(), + parent='parent_value', + ) + + +def test_list_contexts_pager(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + context.Context(), + ], + next_page_token='abc', + ), + metadata_service.ListContextsResponse( + contexts=[], + next_page_token='def', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + ], + next_page_token='ghi', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_contexts(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, context.Context) + for i in results) + +def test_list_contexts_pages(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + context.Context(), + ], + next_page_token='abc', + ), + metadata_service.ListContextsResponse( + contexts=[], + next_page_token='def', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + ], + next_page_token='ghi', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + ], + ), + RuntimeError, + ) + pages = list(client.list_contexts(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_contexts_async_pager(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + context.Context(), + ], + next_page_token='abc', + ), + metadata_service.ListContextsResponse( + contexts=[], + next_page_token='def', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + ], + next_page_token='ghi', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_contexts(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, context.Context) + for i in responses) + +@pytest.mark.asyncio +async def test_list_contexts_async_pages(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + context.Context(), + ], + next_page_token='abc', + ), + metadata_service.ListContextsResponse( + contexts=[], + next_page_token='def', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + ], + next_page_token='ghi', + ), + metadata_service.ListContextsResponse( + contexts=[ + context.Context(), + context.Context(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_contexts(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_context(transport: str = 'grpc', request_type=metadata_service.UpdateContextRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context( + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + + parent_contexts=['parent_contexts_value'], + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateContextRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_context.Context) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.parent_contexts == ['parent_contexts_value'] + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_update_context_from_dict(): + test_update_context(request_type=dict) + + +def test_update_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + client.update_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateContextRequest() + +@pytest.mark.asyncio +async def test_update_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateContextRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context( + name='name_value', + display_name='display_name_value', + etag='etag_value', + parent_contexts=['parent_contexts_value'], + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_context.Context) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.parent_contexts == ['parent_contexts_value'] + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_update_context_async_from_dict(): + await test_update_context_async(request_type=dict) + + +def test_update_context_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateContextRequest() + request.context.name = 'context.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + call.return_value = gca_context.Context() + + client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context.name=context.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_context_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateContextRequest() + request.context.name = 'context.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + + await client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context.name=context.name/value', + ) in kw['metadata'] + + +def test_update_context_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_context( + context=gca_context.Context(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].context == gca_context.Context(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_context_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_context( + metadata_service.UpdateContextRequest(), + context=gca_context.Context(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_context_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_context( + context=gca_context.Context(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].context == gca_context.Context(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_context_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_context( + metadata_service.UpdateContextRequest(), + context=gca_context.Context(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_context(transport: str = 'grpc', request_type=metadata_service.DeleteContextRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.DeleteContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_context_from_dict(): + test_delete_context(request_type=dict) + + +def test_delete_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + client.delete_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.DeleteContextRequest() + +@pytest.mark.asyncio +async def test_delete_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.DeleteContextRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.DeleteContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_context_async_from_dict(): + await test_delete_context_async(request_type=dict) + + +def test_delete_context_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.DeleteContextRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_context_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.DeleteContextRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_context_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_context( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_context_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_context( + metadata_service.DeleteContextRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_context_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_context), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_context( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_context_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_context( + metadata_service.DeleteContextRequest(), + name='name_value', + ) + + +def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_type=metadata_service.AddContextArtifactsAndExecutionsRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse( + ) + + response = client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, metadata_service.AddContextArtifactsAndExecutionsResponse) + + +def test_add_context_artifacts_and_executions_from_dict(): + test_add_context_artifacts_and_executions(request_type=dict) + + +def test_add_context_artifacts_and_executions_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + client.add_context_artifacts_and_executions() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddContextArtifactsAndExecutionsRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse( + )) + + response = await client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_service.AddContextArtifactsAndExecutionsResponse) + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_async_from_dict(): + await test_add_context_artifacts_and_executions_async(request_type=dict) + + +def test_add_context_artifacts_and_executions_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddContextArtifactsAndExecutionsRequest() + request.context = 'context/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + + client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context=context/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddContextArtifactsAndExecutionsRequest() + request.context = 'context/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse()) + + await client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context=context/value', + ) in kw['metadata'] + + +def test_add_context_artifacts_and_executions_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.add_context_artifacts_and_executions( + context='context_value', + artifacts=['artifacts_value'], + executions=['executions_value'], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].context == 'context_value' + + assert args[0].artifacts == ['artifacts_value'] + + assert args[0].executions == ['executions_value'] + + +def test_add_context_artifacts_and_executions_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.add_context_artifacts_and_executions( + metadata_service.AddContextArtifactsAndExecutionsRequest(), + context='context_value', + artifacts=['artifacts_value'], + executions=['executions_value'], + ) + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.add_context_artifacts_and_executions( + context='context_value', + artifacts=['artifacts_value'], + executions=['executions_value'], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].context == 'context_value' + + assert args[0].artifacts == ['artifacts_value'] + + assert args[0].executions == ['executions_value'] + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.add_context_artifacts_and_executions( + metadata_service.AddContextArtifactsAndExecutionsRequest(), + context='context_value', + artifacts=['artifacts_value'], + executions=['executions_value'], + ) + + +def test_add_context_children(transport: str = 'grpc', request_type=metadata_service.AddContextChildrenRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextChildrenResponse( + ) + + response = client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddContextChildrenRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, metadata_service.AddContextChildrenResponse) + + +def test_add_context_children_from_dict(): + test_add_context_children(request_type=dict) + + +def test_add_context_children_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + client.add_context_children() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddContextChildrenRequest() + +@pytest.mark.asyncio +async def test_add_context_children_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddContextChildrenRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse( + )) + + response = await client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddContextChildrenRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_service.AddContextChildrenResponse) + + +@pytest.mark.asyncio +async def test_add_context_children_async_from_dict(): + await test_add_context_children_async(request_type=dict) + + +def test_add_context_children_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddContextChildrenRequest() + request.context = 'context/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + call.return_value = metadata_service.AddContextChildrenResponse() + + client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context=context/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_add_context_children_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddContextChildrenRequest() + request.context = 'context/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse()) + + await client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context=context/value', + ) in kw['metadata'] + + +def test_add_context_children_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextChildrenResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.add_context_children( + context='context_value', + child_contexts=['child_contexts_value'], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].context == 'context_value' + + assert args[0].child_contexts == ['child_contexts_value'] + + +def test_add_context_children_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.add_context_children( + metadata_service.AddContextChildrenRequest(), + context='context_value', + child_contexts=['child_contexts_value'], + ) + + +@pytest.mark.asyncio +async def test_add_context_children_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextChildrenResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.add_context_children( + context='context_value', + child_contexts=['child_contexts_value'], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].context == 'context_value' + + assert args[0].child_contexts == ['child_contexts_value'] + + +@pytest.mark.asyncio +async def test_add_context_children_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.add_context_children( + metadata_service.AddContextChildrenRequest(), + context='context_value', + child_contexts=['child_contexts_value'], + ) + + +def test_query_context_lineage_subgraph(transport: str = 'grpc', request_type=metadata_service.QueryContextLineageSubgraphRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph( + ) + + response = client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, lineage_subgraph.LineageSubgraph) + + +def test_query_context_lineage_subgraph_from_dict(): + test_query_context_lineage_subgraph(request_type=dict) + + +def test_query_context_lineage_subgraph_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + client.query_context_lineage_subgraph() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryContextLineageSubgraphRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( + )) + + response = await client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, lineage_subgraph.LineageSubgraph) + + +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_async_from_dict(): + await test_query_context_lineage_subgraph_async(request_type=dict) + + +def test_query_context_lineage_subgraph_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.QueryContextLineageSubgraphRequest() + request.context = 'context/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + call.return_value = lineage_subgraph.LineageSubgraph() + + client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context=context/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.QueryContextLineageSubgraphRequest() + request.context = 'context/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + + await client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'context=context/value', + ) in kw['metadata'] + + +def test_query_context_lineage_subgraph_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.query_context_lineage_subgraph( + context='context_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].context == 'context_value' + + +def test_query_context_lineage_subgraph_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.query_context_lineage_subgraph( + metadata_service.QueryContextLineageSubgraphRequest(), + context='context_value', + ) + + +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_context_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.query_context_lineage_subgraph( + context='context_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].context == 'context_value' + + +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.query_context_lineage_subgraph( + metadata_service.QueryContextLineageSubgraphRequest(), + context='context_value', + ) + + +def test_create_execution(transport: str = 'grpc', request_type=metadata_service.CreateExecutionRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution( + name='name_value', + + display_name='display_name_value', + + state=gca_execution.Execution.State.NEW, + + etag='etag_value', + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateExecutionRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_execution.Execution) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == gca_execution.Execution.State.NEW + + assert response.etag == 'etag_value' + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_create_execution_from_dict(): + test_create_execution(request_type=dict) + + +def test_create_execution_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + client.create_execution() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateExecutionRequest() + +@pytest.mark.asyncio +async def test_create_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateExecutionRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution( + name='name_value', + display_name='display_name_value', + state=gca_execution.Execution.State.NEW, + etag='etag_value', + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateExecutionRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_execution.Execution) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == gca_execution.Execution.State.NEW + + assert response.etag == 'etag_value' + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_create_execution_async_from_dict(): + await test_create_execution_async(request_type=dict) + + +def test_create_execution_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateExecutionRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + call.return_value = gca_execution.Execution() + + client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_execution_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateExecutionRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + + await client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_execution_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_execution( + parent='parent_value', + execution=gca_execution.Execution(name='name_value'), + execution_id='execution_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].execution == gca_execution.Execution(name='name_value') + + assert args[0].execution_id == 'execution_id_value' + + +def test_create_execution_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_execution( + metadata_service.CreateExecutionRequest(), + parent='parent_value', + execution=gca_execution.Execution(name='name_value'), + execution_id='execution_id_value', + ) + + +@pytest.mark.asyncio +async def test_create_execution_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_execution( + parent='parent_value', + execution=gca_execution.Execution(name='name_value'), + execution_id='execution_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].execution == gca_execution.Execution(name='name_value') + + assert args[0].execution_id == 'execution_id_value' + + +@pytest.mark.asyncio +async def test_create_execution_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_execution( + metadata_service.CreateExecutionRequest(), + parent='parent_value', + execution=gca_execution.Execution(name='name_value'), + execution_id='execution_id_value', + ) + + +def test_get_execution(transport: str = 'grpc', request_type=metadata_service.GetExecutionRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = execution.Execution( + name='name_value', + + display_name='display_name_value', + + state=execution.Execution.State.NEW, + + etag='etag_value', + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetExecutionRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, execution.Execution) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == execution.Execution.State.NEW + + assert response.etag == 'etag_value' + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_get_execution_from_dict(): + test_get_execution(request_type=dict) + + +def test_get_execution_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + client.get_execution() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetExecutionRequest() + +@pytest.mark.asyncio +async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetExecutionRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution( + name='name_value', + display_name='display_name_value', + state=execution.Execution.State.NEW, + etag='etag_value', + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetExecutionRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, execution.Execution) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == execution.Execution.State.NEW + + assert response.etag == 'etag_value' + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_get_execution_async_from_dict(): + await test_get_execution_async(request_type=dict) + + +def test_get_execution_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetExecutionRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + call.return_value = execution.Execution() + + client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_execution_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetExecutionRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) + + await client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_execution_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = execution.Execution() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_execution( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_execution_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_execution( + metadata_service.GetExecutionRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_execution_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = execution.Execution() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_execution( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_execution_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_execution( + metadata_service.GetExecutionRequest(), + name='name_value', + ) + + +def test_list_executions(transport: str = 'grpc', request_type=metadata_service.ListExecutionsRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListExecutionsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListExecutionsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListExecutionsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_executions_from_dict(): + test_list_executions(request_type=dict) + + +def test_list_executions_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + client.list_executions() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListExecutionsRequest() + +@pytest.mark.asyncio +async def test_list_executions_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListExecutionsRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListExecutionsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListExecutionsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_executions_async_from_dict(): + await test_list_executions_async(request_type=dict) + + +def test_list_executions_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListExecutionsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + call.return_value = metadata_service.ListExecutionsResponse() + + client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_executions_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListExecutionsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse()) + + await client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_executions_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListExecutionsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_executions( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_executions_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_executions( + metadata_service.ListExecutionsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_executions_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListExecutionsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_executions( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_executions_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_executions( + metadata_service.ListExecutionsRequest(), + parent='parent_value', + ) + + +def test_list_executions_pager(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token='abc', + ), + metadata_service.ListExecutionsResponse( + executions=[], + next_page_token='def', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + ], + next_page_token='ghi', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_executions(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, execution.Execution) + for i in results) + +def test_list_executions_pages(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token='abc', + ), + metadata_service.ListExecutionsResponse( + executions=[], + next_page_token='def', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + ], + next_page_token='ghi', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + ], + ), + RuntimeError, + ) + pages = list(client.list_executions(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_executions_async_pager(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token='abc', + ), + metadata_service.ListExecutionsResponse( + executions=[], + next_page_token='def', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + ], + next_page_token='ghi', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_executions(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, execution.Execution) + for i in responses) + +@pytest.mark.asyncio +async def test_list_executions_async_pages(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token='abc', + ), + metadata_service.ListExecutionsResponse( + executions=[], + next_page_token='def', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + ], + next_page_token='ghi', + ), + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_executions(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_execution(transport: str = 'grpc', request_type=metadata_service.UpdateExecutionRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution( + name='name_value', + + display_name='display_name_value', + + state=gca_execution.Execution.State.NEW, + + etag='etag_value', + + schema_title='schema_title_value', + + schema_version='schema_version_value', + + description='description_value', + + ) + + response = client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateExecutionRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_execution.Execution) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == gca_execution.Execution.State.NEW + + assert response.etag == 'etag_value' + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +def test_update_execution_from_dict(): + test_update_execution(request_type=dict) + + +def test_update_execution_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + client.update_execution() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateExecutionRequest() + +@pytest.mark.asyncio +async def test_update_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateExecutionRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution( + name='name_value', + display_name='display_name_value', + state=gca_execution.Execution.State.NEW, + etag='etag_value', + schema_title='schema_title_value', + schema_version='schema_version_value', + description='description_value', + )) + + response = await client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.UpdateExecutionRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_execution.Execution) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == gca_execution.Execution.State.NEW + + assert response.etag == 'etag_value' + + assert response.schema_title == 'schema_title_value' + + assert response.schema_version == 'schema_version_value' + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_update_execution_async_from_dict(): + await test_update_execution_async(request_type=dict) + + +def test_update_execution_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateExecutionRequest() + request.execution.name = 'execution.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + call.return_value = gca_execution.Execution() + + client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'execution.name=execution.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_execution_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateExecutionRequest() + request.execution.name = 'execution.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + + await client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'execution.name=execution.name/value', + ) in kw['metadata'] + + +def test_update_execution_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_execution( + execution=gca_execution.Execution(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].execution == gca_execution.Execution(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_execution_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_execution( + metadata_service.UpdateExecutionRequest(), + execution=gca_execution.Execution(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_execution_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_execution), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_execution( + execution=gca_execution.Execution(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].execution == gca_execution.Execution(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_execution_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_execution( + metadata_service.UpdateExecutionRequest(), + execution=gca_execution.Execution(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_add_execution_events(transport: str = 'grpc', request_type=metadata_service.AddExecutionEventsRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddExecutionEventsResponse( + ) + + response = client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddExecutionEventsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, metadata_service.AddExecutionEventsResponse) + + +def test_add_execution_events_from_dict(): + test_add_execution_events(request_type=dict) + + +def test_add_execution_events_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + client.add_execution_events() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddExecutionEventsRequest() + +@pytest.mark.asyncio +async def test_add_execution_events_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddExecutionEventsRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse( + )) + + response = await client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.AddExecutionEventsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_service.AddExecutionEventsResponse) + + +@pytest.mark.asyncio +async def test_add_execution_events_async_from_dict(): + await test_add_execution_events_async(request_type=dict) + + +def test_add_execution_events_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddExecutionEventsRequest() + request.execution = 'execution/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + call.return_value = metadata_service.AddExecutionEventsResponse() + + client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'execution=execution/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_add_execution_events_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddExecutionEventsRequest() + request.execution = 'execution/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse()) + + await client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'execution=execution/value', + ) in kw['metadata'] + + +def test_add_execution_events_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddExecutionEventsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.add_execution_events( + execution='execution_value', + events=[event.Event(artifact='artifact_value')], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].execution == 'execution_value' + + assert args[0].events == [event.Event(artifact='artifact_value')] + + +def test_add_execution_events_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.add_execution_events( + metadata_service.AddExecutionEventsRequest(), + execution='execution_value', + events=[event.Event(artifact='artifact_value')], + ) + + +@pytest.mark.asyncio +async def test_add_execution_events_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_execution_events), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddExecutionEventsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.add_execution_events( + execution='execution_value', + events=[event.Event(artifact='artifact_value')], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].execution == 'execution_value' + + assert args[0].events == [event.Event(artifact='artifact_value')] + + +@pytest.mark.asyncio +async def test_add_execution_events_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.add_execution_events( + metadata_service.AddExecutionEventsRequest(), + execution='execution_value', + events=[event.Event(artifact='artifact_value')], + ) + + +def test_query_execution_inputs_and_outputs(transport: str = 'grpc', request_type=metadata_service.QueryExecutionInputsAndOutputsRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph( + ) + + response = client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, lineage_subgraph.LineageSubgraph) + + +def test_query_execution_inputs_and_outputs_from_dict(): + test_query_execution_inputs_and_outputs(request_type=dict) + + +def test_query_execution_inputs_and_outputs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + client.query_execution_inputs_and_outputs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryExecutionInputsAndOutputsRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( + )) + + response = await client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, lineage_subgraph.LineageSubgraph) + + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_async_from_dict(): + await test_query_execution_inputs_and_outputs_async(request_type=dict) + + +def test_query_execution_inputs_and_outputs_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.QueryExecutionInputsAndOutputsRequest() + request.execution = 'execution/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + call.return_value = lineage_subgraph.LineageSubgraph() + + client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'execution=execution/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.QueryExecutionInputsAndOutputsRequest() + request.execution = 'execution/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + + await client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'execution=execution/value', + ) in kw['metadata'] + + +def test_query_execution_inputs_and_outputs_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.query_execution_inputs_and_outputs( + execution='execution_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].execution == 'execution_value' + + +def test_query_execution_inputs_and_outputs_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.query_execution_inputs_and_outputs( + metadata_service.QueryExecutionInputsAndOutputsRequest(), + execution='execution_value', + ) + + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.query_execution_inputs_and_outputs( + execution='execution_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].execution == 'execution_value' + + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.query_execution_inputs_and_outputs( + metadata_service.QueryExecutionInputsAndOutputsRequest(), + execution='execution_value', + ) + + +def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_service.CreateMetadataSchemaRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_metadata_schema.MetadataSchema( + name='name_value', + + schema_version='schema_version_value', + + schema='schema_value', + + schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + + description='description_value', + + ) + + response = client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateMetadataSchemaRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_metadata_schema.MetadataSchema) + + assert response.name == 'name_value' + + assert response.schema_version == 'schema_version_value' + + assert response.schema == 'schema_value' + + assert response.schema_type == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + + assert response.description == 'description_value' + + +def test_create_metadata_schema_from_dict(): + test_create_metadata_schema(request_type=dict) + + +def test_create_metadata_schema_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + client.create_metadata_schema() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateMetadataSchemaRequest() + +@pytest.mark.asyncio +async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateMetadataSchemaRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema( + name='name_value', + schema_version='schema_version_value', + schema='schema_value', + schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + description='description_value', + )) + + response = await client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.CreateMetadataSchemaRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_metadata_schema.MetadataSchema) + + assert response.name == 'name_value' + + assert response.schema_version == 'schema_version_value' + + assert response.schema == 'schema_value' + + assert response.schema_type == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_create_metadata_schema_async_from_dict(): + await test_create_metadata_schema_async(request_type=dict) + + +def test_create_metadata_schema_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateMetadataSchemaRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + call.return_value = gca_metadata_schema.MetadataSchema() + + client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_metadata_schema_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.CreateMetadataSchemaRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema()) + + await client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_metadata_schema_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_metadata_schema.MetadataSchema() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_metadata_schema( + parent='parent_value', + metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), + metadata_schema_id='metadata_schema_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema(name='name_value') + + assert args[0].metadata_schema_id == 'metadata_schema_id_value' + + +def test_create_metadata_schema_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_metadata_schema( + metadata_service.CreateMetadataSchemaRequest(), + parent='parent_value', + metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), + metadata_schema_id='metadata_schema_id_value', + ) + + +@pytest.mark.asyncio +async def test_create_metadata_schema_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_metadata_schema.MetadataSchema() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_metadata_schema( + parent='parent_value', + metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), + metadata_schema_id='metadata_schema_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema(name='name_value') + + assert args[0].metadata_schema_id == 'metadata_schema_id_value' + + +@pytest.mark.asyncio +async def test_create_metadata_schema_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_metadata_schema( + metadata_service.CreateMetadataSchemaRequest(), + parent='parent_value', + metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), + metadata_schema_id='metadata_schema_id_value', + ) + + +def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_service.GetMetadataSchemaRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_schema.MetadataSchema( + name='name_value', + + schema_version='schema_version_value', + + schema='schema_value', + + schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + + description='description_value', + + ) + + response = client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetMetadataSchemaRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, metadata_schema.MetadataSchema) + + assert response.name == 'name_value' + + assert response.schema_version == 'schema_version_value' + + assert response.schema == 'schema_value' + + assert response.schema_type == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + + assert response.description == 'description_value' + + +def test_get_metadata_schema_from_dict(): + test_get_metadata_schema(request_type=dict) + + +def test_get_metadata_schema_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + client.get_metadata_schema() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetMetadataSchemaRequest() + +@pytest.mark.asyncio +async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetMetadataSchemaRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema( + name='name_value', + schema_version='schema_version_value', + schema='schema_value', + schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + description='description_value', + )) + + response = await client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.GetMetadataSchemaRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_schema.MetadataSchema) + + assert response.name == 'name_value' + + assert response.schema_version == 'schema_version_value' + + assert response.schema == 'schema_value' + + assert response.schema_type == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + + assert response.description == 'description_value' + + +@pytest.mark.asyncio +async def test_get_metadata_schema_async_from_dict(): + await test_get_metadata_schema_async(request_type=dict) + + +def test_get_metadata_schema_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetMetadataSchemaRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + call.return_value = metadata_schema.MetadataSchema() + + client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_metadata_schema_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.GetMetadataSchemaRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema()) + + await client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_metadata_schema_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_schema.MetadataSchema() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_metadata_schema( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_metadata_schema_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_metadata_schema( + metadata_service.GetMetadataSchemaRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_metadata_schema_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_metadata_schema), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_schema.MetadataSchema() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_metadata_schema( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_metadata_schema_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_metadata_schema( + metadata_service.GetMetadataSchemaRequest(), + name='name_value', + ) + + +def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_service.ListMetadataSchemasRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListMetadataSchemasResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListMetadataSchemasRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListMetadataSchemasPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_metadata_schemas_from_dict(): + test_list_metadata_schemas(request_type=dict) + + +def test_list_metadata_schemas_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + client.list_metadata_schemas() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListMetadataSchemasRequest() + +@pytest.mark.asyncio +async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListMetadataSchemasRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.ListMetadataSchemasRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListMetadataSchemasAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_metadata_schemas_async_from_dict(): + await test_list_metadata_schemas_async(request_type=dict) + + +def test_list_metadata_schemas_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListMetadataSchemasRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + call.return_value = metadata_service.ListMetadataSchemasResponse() + + client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_metadata_schemas_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListMetadataSchemasRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse()) + + await client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_metadata_schemas_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListMetadataSchemasResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_metadata_schemas( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_metadata_schemas_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_metadata_schemas( + metadata_service.ListMetadataSchemasRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_metadata_schemas_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListMetadataSchemasResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_metadata_schemas( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_metadata_schemas_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_metadata_schemas( + metadata_service.ListMetadataSchemasRequest(), + parent='parent_value', + ) + + +def test_list_metadata_schemas_pager(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[], + next_page_token='def', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_metadata_schemas(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, metadata_schema.MetadataSchema) + for i in results) + +def test_list_metadata_schemas_pages(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[], + next_page_token='def', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + ), + RuntimeError, + ) + pages = list(client.list_metadata_schemas(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_metadata_schemas_async_pager(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[], + next_page_token='def', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_metadata_schemas(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, metadata_schema.MetadataSchema) + for i in responses) + +@pytest.mark.asyncio +async def test_list_metadata_schemas_async_pages(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_metadata_schemas), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + next_page_token='abc', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[], + next_page_token='def', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + ], + next_page_token='ghi', + ), + metadata_service.ListMetadataSchemasResponse( + metadata_schemas=[ + metadata_schema.MetadataSchema(), + metadata_schema.MetadataSchema(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_metadata_schemas(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MetadataServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MetadataServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = MetadataServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.MetadataServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.MetadataServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.MetadataServiceGrpcTransport, + ) + + +def test_metadata_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.MetadataServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_metadata_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.MetadataServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_metadata_store', + 'get_metadata_store', + 'list_metadata_stores', + 'delete_metadata_store', + 'create_artifact', + 'get_artifact', + 'list_artifacts', + 'update_artifact', + 'create_context', + 'get_context', + 'list_contexts', + 'update_context', + 'delete_context', + 'add_context_artifacts_and_executions', + 'add_context_children', + 'query_context_lineage_subgraph', + 'create_execution', + 'get_execution', + 'list_executions', + 'update_execution', + 'add_execution_events', + 'query_execution_inputs_and_outputs', + 'create_metadata_schema', + 'get_metadata_schema', + 'list_metadata_schemas', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_metadata_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MetadataServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_metadata_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MetadataServiceTransport() + adc.assert_called_once() + + +def test_metadata_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + MetadataServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_metadata_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.MetadataServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) +def test_metadata_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + + +def test_metadata_service_host_no_port(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_metadata_service_host_with_port(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_metadata_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.MetadataServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_metadata_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.MetadataServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) +def test_metadata_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) +def test_metadata_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_metadata_service_grpc_lro_client(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_metadata_service_grpc_lro_async_client(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_artifact_path(): + project = "squid" + location = "clam" + metadata_store = "whelk" + artifact = "octopus" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(project=project, location=location, metadata_store=metadata_store, artifact=artifact, ) + actual = MetadataServiceClient.artifact_path(project, location, metadata_store, artifact) + assert expected == actual + + +def test_parse_artifact_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "artifact": "mussel", + + } + path = MetadataServiceClient.artifact_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_artifact_path(path) + assert expected == actual + +def test_context_path(): + project = "winkle" + location = "nautilus" + metadata_store = "scallop" + context = "abalone" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(project=project, location=location, metadata_store=metadata_store, context=context, ) + actual = MetadataServiceClient.context_path(project, location, metadata_store, context) + assert expected == actual + + +def test_parse_context_path(): + expected = { + "project": "squid", + "location": "clam", + "metadata_store": "whelk", + "context": "octopus", + + } + path = MetadataServiceClient.context_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_context_path(path) + assert expected == actual + +def test_execution_path(): + project = "oyster" + location = "nudibranch" + metadata_store = "cuttlefish" + execution = "mussel" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(project=project, location=location, metadata_store=metadata_store, execution=execution, ) + actual = MetadataServiceClient.execution_path(project, location, metadata_store, execution) + assert expected == actual + + +def test_parse_execution_path(): + expected = { + "project": "winkle", + "location": "nautilus", + "metadata_store": "scallop", + "execution": "abalone", + + } + path = MetadataServiceClient.execution_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_execution_path(path) + assert expected == actual + +def test_metadata_schema_path(): + project = "squid" + location = "clam" + metadata_store = "whelk" + metadata_schema = "octopus" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format(project=project, location=location, metadata_store=metadata_store, metadata_schema=metadata_schema, ) + actual = MetadataServiceClient.metadata_schema_path(project, location, metadata_store, metadata_schema) + assert expected == actual + + +def test_parse_metadata_schema_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "metadata_schema": "mussel", + + } + path = MetadataServiceClient.metadata_schema_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_metadata_schema_path(path) + assert expected == actual + +def test_metadata_store_path(): + project = "winkle" + location = "nautilus" + metadata_store = "scallop" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format(project=project, location=location, metadata_store=metadata_store, ) + actual = MetadataServiceClient.metadata_store_path(project, location, metadata_store) + assert expected == actual + + +def test_parse_metadata_store_path(): + expected = { + "project": "abalone", + "location": "squid", + "metadata_store": "clam", + + } + path = MetadataServiceClient.metadata_store_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_metadata_store_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "whelk" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = MetadataServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + + } + path = MetadataServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "oyster" + + expected = "folders/{folder}".format(folder=folder, ) + actual = MetadataServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + + } + path = MetadataServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "cuttlefish" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = MetadataServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + + } + path = MetadataServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "winkle" + + expected = "projects/{project}".format(project=project, ) + actual = MetadataServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + + } + path = MetadataServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "scallop" + location = "abalone" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = MetadataServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + + } + path = MetadataServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = MetadataServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.MetadataServiceTransport, '_prep_wrapped_messages') as prep: + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.MetadataServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = MetadataServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 37ae2b65e8..85cf790381 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.migration_service import ( - MigrationServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.migration_service import ( - MigrationServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient from google.cloud.aiplatform_v1beta1.services.migration_service import pagers from google.cloud.aiplatform_v1beta1.services.migration_service import transports from google.cloud.aiplatform_v1beta1.types import migratable_resource @@ -57,11 +53,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -72,53 +64,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert ( - MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) + assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + MigrationServiceClient, + MigrationServiceAsyncClient, +]) def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + MigrationServiceClient, + MigrationServiceAsyncClient, +]) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -128,7 +103,7 @@ def test_migration_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_migration_service_client_get_transport_class(): @@ -142,44 +117,29 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - MigrationServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceClient), -) -@mock.patch.object( - MigrationServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceAsyncClient), -) -def test_migration_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +def test_migration_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -195,7 +155,7 @@ def test_migration_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -211,7 +171,7 @@ def test_migration_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -231,15 +191,13 @@ def test_migration_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -252,62 +210,26 @@ def test_migration_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - MigrationServiceClient, - transports.MigrationServiceGrpcTransport, - "grpc", - "true", - ), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - MigrationServiceClient, - transports.MigrationServiceGrpcTransport, - "grpc", - "false", - ), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - MigrationServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceClient), -) -@mock.patch.object( - MigrationServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceAsyncClient), -) + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -330,18 +252,10 @@ def test_migration_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -362,14 +276,9 @@ def test_migration_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -383,23 +292,16 @@ def test_migration_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_migration_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -412,24 +314,16 @@ def test_migration_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_migration_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -444,12 +338,10 @@ def test_migration_service_client_client_options_credentials_file( def test_migration_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -462,12 +354,10 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources( - transport: str = "grpc", - request_type=migration_service.SearchMigratableResourcesRequest, -): +def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -476,11 +366,12 @@ def test_search_migratable_resources( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.search_migratable_resources(request) @@ -495,7 +386,7 @@ def test_search_migratable_resources( assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_search_migratable_resources_from_dict(): @@ -506,27 +397,25 @@ def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.SearchMigratableResourcesRequest() - @pytest.mark.asyncio -async def test_search_migratable_resources_async( - transport: str = "grpc_asyncio", - request_type=migration_service.SearchMigratableResourcesRequest, -): +async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -535,14 +424,12 @@ async def test_search_migratable_resources_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( + next_page_token='next_page_token_value', + )) response = await client.search_migratable_resources(request) @@ -555,7 +442,7 @@ async def test_search_migratable_resources_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -564,17 +451,19 @@ async def test_search_migratable_resources_async_from_dict(): def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -586,7 +475,10 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -598,15 +490,13 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse() - ) + type(client.transport.search_migratable_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) await client.search_migratable_resources(request) @@ -617,39 +507,49 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources(parent="parent_value",) + client.search_migratable_resources( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), parent="parent_value", + migration_service.SearchMigratableResourcesRequest(), + parent='parent_value', ) @@ -661,24 +561,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources(parent="parent_value",) + response = await client.search_migratable_resources( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio @@ -691,17 +591,20 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), parent="parent_value", + migration_service.SearchMigratableResourcesRequest(), + parent='parent_value', ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -710,14 +613,17 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -730,7 +636,9 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.search_migratable_resources(request={}) @@ -738,18 +646,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, migratable_resource.MigratableResource) for i in results - ) - + assert all(isinstance(i, migratable_resource.MigratableResource) + for i in results) def test_search_migratable_resources_pages(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -758,14 +666,17 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -776,20 +687,19 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.search_migratable_resources), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -798,14 +708,17 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -816,27 +729,25 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, migratable_resource.MigratableResource) for i in responses - ) - + assert all(isinstance(i, migratable_resource.MigratableResource) + for i in responses) @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.search_migratable_resources), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -845,14 +756,17 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -865,15 +779,14 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources( - transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest -): +def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -882,10 +795,10 @@ def test_batch_migrate_resources( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.batch_migrate_resources(request) @@ -907,27 +820,25 @@ def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.BatchMigrateResourcesRequest() - @pytest.mark.asyncio -async def test_batch_migrate_resources_async( - transport: str = "grpc_asyncio", - request_type=migration_service.BatchMigrateResourcesRequest, -): +async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -936,11 +847,11 @@ async def test_batch_migrate_resources_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.batch_migrate_resources(request) @@ -961,18 +872,20 @@ async def test_batch_migrate_resources_async_from_dict(): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.batch_migrate_resources(request) @@ -983,7 +896,10 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -995,15 +911,13 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.batch_migrate_resources(request) @@ -1014,30 +928,29 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) # Establish that the underlying call was made with the expected @@ -1045,33 +958,23 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].migrate_resource_requests == [ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ] + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) @@ -1083,25 +986,19 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_migrate_resources( - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) # Establish that the underlying call was made with the expected @@ -1109,15 +1006,9 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].migrate_resource_requests == [ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ] + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] @pytest.mark.asyncio @@ -1131,14 +1022,8 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) @@ -1149,7 +1034,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1168,7 +1054,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1196,16 +1083,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1213,8 +1097,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.MigrationServiceGrpcTransport, + ) def test_migration_service_base_transport_error(): @@ -1222,15 +1111,13 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1239,9 +1126,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "search_migratable_resources", - "batch_migrate_resources", - ) + 'search_migratable_resources', + 'batch_migrate_resources', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1254,28 +1141,23 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1284,11 +1166,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -1296,25 +1178,19 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) -def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +def test_migration_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1323,13 +1199,15 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_ transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1344,40 +1222,38 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_migration_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1385,11 +1261,12 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1398,22 +1275,12 @@ def test_migration_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1422,7 +1289,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1438,7 +1305,9 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1452,23 +1321,17 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) -def test_migration_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +def test_migration_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1485,7 +1348,9 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1498,12 +1363,16 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1511,12 +1380,16 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1527,20 +1400,17 @@ def test_annotated_dataset_path(): dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( - project=project, dataset=dataset, annotated_dataset=annotated_dataset, - ) - actual = MigrationServiceClient.annotated_dataset_path( - project, dataset, annotated_dataset - ) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", + } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1548,22 +1418,22 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual - def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" + location = "mussel" + dataset = "winkle" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, - ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1571,24 +1441,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" + project = "squid" + location = "clam" + dataset = "whelk" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", - "dataset": "octopus", + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1596,24 +1464,20 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" + project = "cuttlefish" + dataset = "mussel" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", - "dataset": "nautilus", + "project": "winkle", + "dataset": "nautilus", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1621,24 +1485,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", + "project": "clam", + "location": "whelk", + "model": "octopus", + } path = MigrationServiceClient.model_path(**expected) @@ -1646,24 +1508,22 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual - def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", + "project": "mussel", + "location": "winkle", + "model": "nautilus", + } path = MigrationServiceClient.model_path(**expected) @@ -1671,24 +1531,22 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual - def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format( - project=project, model=model, version=version, - ) + expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", + "project": "clam", + "model": "whelk", + "version": "octopus", + } path = MigrationServiceClient.version_path(**expected) @@ -1696,20 +1554,18 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "nudibranch", + } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1717,18 +1573,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "mussel", + } path = MigrationServiceClient.common_folder_path(**expected) @@ -1736,18 +1592,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "nautilus", + } path = MigrationServiceClient.common_organization_path(**expected) @@ -1755,18 +1611,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "abalone", + } path = MigrationServiceClient.common_project_path(**expected) @@ -1774,22 +1630,20 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "whelk", + "location": "octopus", + } path = MigrationServiceClient.common_location_path(**expected) @@ -1801,19 +1655,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.MigrationServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.MigrationServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index 51cbd4583f..ffe3ecd828 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.model_service import ( - ModelServiceAsyncClient, -) +from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceAsyncClient from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.services.model_service import transports @@ -68,11 +66,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -83,49 +77,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) + assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [ModelServiceClient, ModelServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + ModelServiceClient, + ModelServiceAsyncClient, +]) def test_model_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [ModelServiceClient, ModelServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + ModelServiceClient, + ModelServiceAsyncClient, +]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -135,7 +116,7 @@ def test_model_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_model_service_client_get_transport_class(): @@ -149,42 +130,29 @@ def test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) -def test_model_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -200,7 +168,7 @@ def test_model_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -216,7 +184,7 @@ def test_model_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -236,15 +204,13 @@ def test_model_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -257,50 +223,26 @@ def test_model_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -323,18 +265,10 @@ def test_model_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -355,14 +289,9 @@ def test_model_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -376,23 +305,16 @@ def test_model_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -405,24 +327,16 @@ def test_model_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -437,11 +351,11 @@ def test_model_service_client_client_options_credentials_file( def test_model_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + client = ModelServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -453,11 +367,10 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model( - transport: str = "grpc", request_type=model_service.UploadModelRequest -): +def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -465,9 +378,11 @@ def test_upload_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.upload_model(request) @@ -489,24 +404,25 @@ def test_upload_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: client.upload_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UploadModelRequest() - @pytest.mark.asyncio -async def test_upload_model_async( - transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest -): +async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UploadModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -514,10 +430,12 @@ async def test_upload_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.upload_model(request) @@ -538,16 +456,20 @@ async def test_upload_model_async_from_dict(): def test_upload_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.upload_model(request) @@ -558,23 +480,28 @@ def test_upload_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.upload_model(request) @@ -585,21 +512,29 @@ async def test_upload_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_upload_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -607,40 +542,47 @@ def test_upload_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') def test_upload_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.upload_model( model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) @pytest.mark.asyncio async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -648,28 +590,31 @@ async def test_upload_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') @pytest.mark.asyncio async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.upload_model( model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) -def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): +def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -677,21 +622,31 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + ) response = client.get_model(request) @@ -706,31 +661,25 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR assert isinstance(response, model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_model_from_dict(): @@ -741,24 +690,25 @@ def test_get_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelRequest() - @pytest.mark.asyncio -async def test_get_model_async( - transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest -): +async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -766,28 +716,22 @@ async def test_get_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) response = await client.get_model(request) @@ -800,31 +744,25 @@ async def test_get_model_async( # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -833,15 +771,19 @@ async def test_get_model_async_from_dict(): def test_get_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: call.return_value = model.Model() client.get_model(request) @@ -853,20 +795,27 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -878,79 +827,99 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model(name="name_value",) + client.get_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), name="name_value", + model_service.GetModelRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model(name="name_value",) + response = await client.get_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), name="name_value", + model_service.GetModelRequest(), + name='name_value', ) -def test_list_models( - transport: str = "grpc", request_type=model_service.ListModelsRequest -): +def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -958,10 +927,13 @@ def test_list_models( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_models(request) @@ -976,7 +948,7 @@ def test_list_models( assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_models_from_dict(): @@ -987,24 +959,25 @@ def test_list_models_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelsRequest() - @pytest.mark.asyncio -async def test_list_models_async( - transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest -): +async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1012,11 +985,13 @@ async def test_list_models_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_models(request) @@ -1029,7 +1004,7 @@ async def test_list_models_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1038,15 +1013,19 @@ async def test_list_models_async_from_dict(): def test_list_models_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: call.return_value = model_service.ListModelsResponse() client.list_models(request) @@ -1058,23 +1037,28 @@ def test_list_models_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) await client.list_models(request) @@ -1085,98 +1069,138 @@ async def test_list_models_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_models_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_models(parent="parent_value",) + client.list_models( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_models_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_models( - model_service.ListModelsRequest(), parent="parent_value", + model_service.ListModelsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_models(parent="parent_value",) + response = await client.list_models( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_models( - model_service.ListModelsRequest(), parent="parent_value", + model_service.ListModelsRequest(), + parent='parent_value', ) def test_list_models_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_models(request={}) @@ -1184,96 +1208,147 @@ def test_list_models_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model.Model) for i in results) - + assert all(isinstance(i, model.Model) + for i in results) def test_list_models_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_models_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_models), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) for i in responses) - + assert all(isinstance(i, model.Model) + for i in responses) @pytest.mark.asyncio async def test_list_models_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_models), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_models(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_update_model( - transport: str = "grpc", request_type=model_service.UpdateModelRequest -): +def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1281,21 +1356,31 @@ def test_update_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + ) response = client.update_model(request) @@ -1310,31 +1395,25 @@ def test_update_model( assert isinstance(response, gca_model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_model_from_dict(): @@ -1345,24 +1424,25 @@ def test_update_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: client.update_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateModelRequest() - @pytest.mark.asyncio -async def test_update_model_async( - transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest -): +async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1370,28 +1450,22 @@ async def test_update_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) response = await client.update_model(request) @@ -1404,31 +1478,25 @@ async def test_update_model_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' @pytest.mark.asyncio @@ -1437,15 +1505,19 @@ async def test_update_model_async_from_dict(): def test_update_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" + request.model.name = 'model.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: call.return_value = gca_model.Model() client.update_model(request) @@ -1457,20 +1529,27 @@ def test_update_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'model.name=model.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" + request.model.name = 'model.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) await client.update_model(request) @@ -1482,22 +1561,29 @@ async def test_update_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'model.name=model.name/value', + ) in kw['metadata'] def test_update_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1505,30 +1591,36 @@ def test_update_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() @@ -1536,8 +1628,8 @@ async def test_update_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1545,30 +1637,31 @@ async def test_update_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_delete_model( - transport: str = "grpc", request_type=model_service.DeleteModelRequest -): +def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1576,9 +1669,11 @@ def test_delete_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_model(request) @@ -1600,24 +1695,25 @@ def test_delete_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: client.delete_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.DeleteModelRequest() - @pytest.mark.asyncio -async def test_delete_model_async( - transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest -): +async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1625,10 +1721,12 @@ async def test_delete_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_model(request) @@ -1649,16 +1747,20 @@ async def test_delete_model_async_from_dict(): def test_delete_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_model(request) @@ -1669,23 +1771,28 @@ def test_delete_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_model(request) @@ -1696,81 +1803,101 @@ async def test_delete_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model(name="name_value",) + client.delete_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_model( - model_service.DeleteModelRequest(), name="name_value", + model_service.DeleteModelRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_model(name="name_value",) + response = await client.delete_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model( - model_service.DeleteModelRequest(), name="name_value", + model_service.DeleteModelRequest(), + name='name_value', ) -def test_export_model( - transport: str = "grpc", request_type=model_service.ExportModelRequest -): +def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1778,9 +1905,11 @@ def test_export_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.export_model(request) @@ -1802,24 +1931,25 @@ def test_export_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: client.export_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ExportModelRequest() - @pytest.mark.asyncio -async def test_export_model_async( - transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest -): +async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=model_service.ExportModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1827,10 +1957,12 @@ async def test_export_model_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.export_model(request) @@ -1851,16 +1983,20 @@ async def test_export_model_async_from_dict(): def test_export_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.export_model(request) @@ -1871,23 +2007,28 @@ def test_export_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.export_model(request) @@ -1898,24 +2039,29 @@ async def test_export_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_export_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) # Establish that the underlying call was made with the expected @@ -1923,47 +2069,47 @@ def test_export_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') def test_export_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_model( model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) @pytest.mark.asyncio async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) # Establish that the underlying call was made with the expected @@ -1971,34 +2117,31 @@ async def test_export_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') @pytest.mark.asyncio async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_model( model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) -def test_get_model_evaluation( - transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest -): +def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2007,13 +2150,16 @@ def test_get_model_evaluation( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + + slice_dimensions=['slice_dimensions_value'], + ) response = client.get_model_evaluation(request) @@ -2028,11 +2174,11 @@ def test_get_model_evaluation( assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' - assert response.slice_dimensions == ["slice_dimensions_value"] + assert response.slice_dimensions == ['slice_dimensions_value'] def test_get_model_evaluation_from_dict(): @@ -2043,27 +2189,25 @@ def test_get_model_evaluation_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: client.get_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationRequest() - @pytest.mark.asyncio -async def test_get_model_evaluation_async( - transport: str = "grpc_asyncio", - request_type=model_service.GetModelEvaluationRequest, -): +async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2072,16 +2216,14 @@ async def test_get_model_evaluation_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation( + name='name_value', + metrics_schema_uri='metrics_schema_uri_value', + slice_dimensions=['slice_dimensions_value'], + )) response = await client.get_model_evaluation(request) @@ -2094,11 +2236,11 @@ async def test_get_model_evaluation_async( # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' - assert response.slice_dimensions == ["slice_dimensions_value"] + assert response.slice_dimensions == ['slice_dimensions_value'] @pytest.mark.asyncio @@ -2107,17 +2249,19 @@ async def test_get_model_evaluation_async_from_dict(): def test_get_model_evaluation_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: call.return_value = model_evaluation.ModelEvaluation() client.get_model_evaluation(request) @@ -2129,25 +2273,28 @@ def test_get_model_evaluation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_evaluation_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation() - ) + type(client.transport.get_model_evaluation), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) await client.get_model_evaluation(request) @@ -2158,85 +2305,99 @@ async def test_get_model_evaluation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_evaluation_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation(name="name_value",) + client.get_model_evaluation( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", + model_service.GetModelEvaluationRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_evaluation_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation(name="name_value",) + response = await client.get_model_evaluation( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_evaluation_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", + model_service.GetModelEvaluationRequest(), + name='name_value', ) -def test_list_model_evaluations( - transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest -): +def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2245,11 +2406,12 @@ def test_list_model_evaluations( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_model_evaluations(request) @@ -2264,7 +2426,7 @@ def test_list_model_evaluations( assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_model_evaluations_from_dict(): @@ -2275,27 +2437,25 @@ def test_list_model_evaluations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: client.list_model_evaluations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationsRequest() - @pytest.mark.asyncio -async def test_list_model_evaluations_async( - transport: str = "grpc_asyncio", - request_type=model_service.ListModelEvaluationsRequest, -): +async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationsRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2304,14 +2464,12 @@ async def test_list_model_evaluations_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_model_evaluations(request) @@ -2324,7 +2482,7 @@ async def test_list_model_evaluations_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2333,17 +2491,19 @@ async def test_list_model_evaluations_async_from_dict(): def test_list_model_evaluations_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: call.return_value = model_service.ListModelEvaluationsResponse() client.list_model_evaluations(request) @@ -2355,25 +2515,28 @@ def test_list_model_evaluations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse() - ) + type(client.transport.list_model_evaluations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) await client.list_model_evaluations(request) @@ -2384,87 +2547,104 @@ async def test_list_model_evaluations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_model_evaluations_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluations(parent="parent_value",) + client.list_model_evaluations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", + model_service.ListModelEvaluationsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluations(parent="parent_value",) + response = await client.list_model_evaluations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", + model_service.ListModelEvaluationsRequest(), + parent='parent_value', ) def test_list_model_evaluations_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2473,14 +2653,17 @@ def test_list_model_evaluations_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2493,7 +2676,9 @@ def test_list_model_evaluations_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_model_evaluations(request={}) @@ -2501,16 +2686,18 @@ def test_list_model_evaluations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) - + assert all(isinstance(i, model_evaluation.ModelEvaluation) + for i in results) def test_list_model_evaluations_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), "__call__" - ) as call: + type(client.transport.list_model_evaluations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2519,14 +2706,17 @@ def test_list_model_evaluations_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2537,20 +2727,19 @@ def test_list_model_evaluations_pages(): RuntimeError, ) pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2559,14 +2748,17 @@ async def test_list_model_evaluations_async_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2577,25 +2769,25 @@ async def test_list_model_evaluations_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) - + assert all(isinstance(i, model_evaluation.ModelEvaluation) + for i in responses) @pytest.mark.asyncio async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2604,14 +2796,17 @@ async def test_list_model_evaluations_async_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", + model_evaluations=[], + next_page_token='def', ), model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2624,15 +2819,14 @@ async def test_list_model_evaluations_async_pages(): pages = [] async for page_ in (await client.list_model_evaluations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice( - transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest -): +def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2641,11 +2835,14 @@ def test_get_model_evaluation_slice( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + ) response = client.get_model_evaluation_slice(request) @@ -2660,9 +2857,9 @@ def test_get_model_evaluation_slice( assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' def test_get_model_evaluation_slice_from_dict(): @@ -2673,27 +2870,25 @@ def test_get_model_evaluation_slice_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: client.get_model_evaluation_slice() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationSliceRequest() - @pytest.mark.asyncio -async def test_get_model_evaluation_slice_async( - transport: str = "grpc_asyncio", - request_type=model_service.GetModelEvaluationSliceRequest, -): +async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationSliceRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2702,14 +2897,13 @@ async def test_get_model_evaluation_slice_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( + name='name_value', + metrics_schema_uri='metrics_schema_uri_value', + )) response = await client.get_model_evaluation_slice(request) @@ -2722,9 +2916,9 @@ async def test_get_model_evaluation_slice_async( # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.metrics_schema_uri == 'metrics_schema_uri_value' @pytest.mark.asyncio @@ -2733,17 +2927,19 @@ async def test_get_model_evaluation_slice_async_from_dict(): def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: call.return_value = model_evaluation_slice.ModelEvaluationSlice() client.get_model_evaluation_slice(request) @@ -2755,25 +2951,28 @@ def test_get_model_evaluation_slice_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice() - ) + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) await client.get_model_evaluation_slice(request) @@ -2784,85 +2983,99 @@ async def test_get_model_evaluation_slice_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation_slice(name="name_value",) + client.get_model_evaluation_slice( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", + model_service.GetModelEvaluationSliceRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), "__call__" - ) as call: + type(client.transport.get_model_evaluation_slice), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation_slice(name="name_value",) + response = await client.get_model_evaluation_slice( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", + model_service.GetModelEvaluationSliceRequest(), + name='name_value', ) -def test_list_model_evaluation_slices( - transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest -): +def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2871,11 +3084,12 @@ def test_list_model_evaluation_slices( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_model_evaluation_slices(request) @@ -2890,7 +3104,7 @@ def test_list_model_evaluation_slices( assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_model_evaluation_slices_from_dict(): @@ -2901,27 +3115,25 @@ def test_list_model_evaluation_slices_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: client.list_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationSlicesRequest() - @pytest.mark.asyncio -async def test_list_model_evaluation_slices_async( - transport: str = "grpc_asyncio", - request_type=model_service.ListModelEvaluationSlicesRequest, -): +async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationSlicesRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2930,14 +3142,12 @@ async def test_list_model_evaluation_slices_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse( + next_page_token='next_page_token_value', + )) response = await client.list_model_evaluation_slices(request) @@ -2950,7 +3160,7 @@ async def test_list_model_evaluation_slices_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2959,17 +3169,19 @@ async def test_list_model_evaluation_slices_async_from_dict(): def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: call.return_value = model_service.ListModelEvaluationSlicesResponse() client.list_model_evaluation_slices(request) @@ -2981,25 +3193,28 @@ def test_list_model_evaluation_slices_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_model_evaluation_slices_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse() - ) + type(client.transport.list_model_evaluation_slices), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) await client.list_model_evaluation_slices(request) @@ -3010,87 +3225,104 @@ async def test_list_model_evaluation_slices_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluation_slices(parent="parent_value",) + client.list_model_evaluation_slices( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", + model_service.ListModelEvaluationSlicesRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluation_slices(parent="parent_value",) + response = await client.list_model_evaluation_slices( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", + model_service.ListModelEvaluationSlicesRequest(), + parent='parent_value', ) def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3099,16 +3331,17 @@ def test_list_model_evaluation_slices_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3121,7 +3354,9 @@ def test_list_model_evaluation_slices_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_model_evaluation_slices(request={}) @@ -3129,18 +3364,18 @@ def test_list_model_evaluation_slices_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results - ) - + assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in results) def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), "__call__" - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3149,16 +3384,17 @@ def test_list_model_evaluation_slices_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3169,20 +3405,19 @@ def test_list_model_evaluation_slices_pages(): RuntimeError, ) pages = list(client.list_model_evaluation_slices(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3191,16 +3426,17 @@ async def test_list_model_evaluation_slices_async_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3211,28 +3447,25 @@ async def test_list_model_evaluation_slices_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluation_slices(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in responses - ) - + assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in responses) @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_model_evaluation_slices), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3241,16 +3474,17 @@ async def test_list_model_evaluation_slices_async_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="abc", + next_page_token='abc', ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", + model_evaluation_slices=[], + next_page_token='def', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token="ghi", + next_page_token='ghi', ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3261,11 +3495,9 @@ async def test_list_model_evaluation_slices_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( - await client.list_model_evaluation_slices(request={}) - ).pages: + async for page_ in (await client.list_model_evaluation_slices(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token @@ -3276,7 +3508,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3295,7 +3528,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -3323,16 +3557,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3340,8 +3571,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ModelServiceGrpcTransport, + ) def test_model_service_base_transport_error(): @@ -3349,15 +3585,13 @@ def test_model_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_model_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3366,17 +3600,17 @@ def test_model_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "upload_model", - "get_model", - "list_models", - "update_model", - "delete_model", - "export_model", - "get_model_evaluation", - "list_model_evaluations", - "get_model_evaluation_slice", - "list_model_evaluation_slices", - ) + 'upload_model', + 'get_model', + 'list_models', + 'update_model', + 'delete_model', + 'export_model', + 'get_model_evaluation', + 'list_model_evaluations', + 'get_model_evaluation_slice', + 'list_model_evaluation_slices', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3389,28 +3623,23 @@ def test_model_service_base_transport(): def test_model_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_model_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport() @@ -3419,11 +3648,11 @@ def test_model_service_base_transport_with_adc(): def test_model_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) ModelServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -3431,22 +3660,19 @@ def test_model_service_auth_adc(): def test_model_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3455,13 +3681,15 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_clas transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3476,40 +3704,38 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_clas with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_model_service_host_no_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_model_service_host_with_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_model_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3517,11 +3743,12 @@ def test_model_service_grpc_transport_channel(): def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3530,17 +3757,12 @@ def test_model_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3549,7 +3771,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3565,7 +3787,9 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3579,20 +3803,17 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3609,7 +3830,9 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3622,12 +3845,16 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): def test_model_service_grpc_lro_client(): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3635,12 +3862,16 @@ def test_model_service_grpc_lro_client(): def test_model_service_grpc_lro_async_client(): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3651,18 +3882,17 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = ModelServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = ModelServiceClient.endpoint_path(**expected) @@ -3670,24 +3900,22 @@ def test_parse_endpoint_path(): actual = ModelServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = ModelServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = ModelServiceClient.model_path(**expected) @@ -3695,28 +3923,24 @@ def test_parse_model_path(): actual = ModelServiceClient.parse_model_path(path) assert expected == actual - def test_model_evaluation_path(): project = "squid" location = "clam" model = "whelk" evaluation = "octopus" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( - project=project, location=location, model=model, evaluation=evaluation, - ) - actual = ModelServiceClient.model_evaluation_path( - project, location, model, evaluation - ) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation) assert expected == actual def test_parse_model_evaluation_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "model": "cuttlefish", - "evaluation": "mussel", + "project": "oyster", + "location": "nudibranch", + "model": "cuttlefish", + "evaluation": "mussel", + } path = ModelServiceClient.model_evaluation_path(**expected) @@ -3724,7 +3948,6 @@ def test_parse_model_evaluation_path(): actual = ModelServiceClient.parse_model_evaluation_path(path) assert expected == actual - def test_model_evaluation_slice_path(): project = "winkle" location = "nautilus" @@ -3732,26 +3955,19 @@ def test_model_evaluation_slice_path(): evaluation = "abalone" slice = "squid" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( - project=project, - location=location, - model=model, - evaluation=evaluation, - slice=slice, - ) - actual = ModelServiceClient.model_evaluation_slice_path( - project, location, model, evaluation, slice - ) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice) assert expected == actual def test_parse_model_evaluation_slice_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - "evaluation": "oyster", - "slice": "nudibranch", + "project": "clam", + "location": "whelk", + "model": "octopus", + "evaluation": "oyster", + "slice": "nudibranch", + } path = ModelServiceClient.model_evaluation_slice_path(**expected) @@ -3759,26 +3975,22 @@ def test_parse_model_evaluation_slice_path(): actual = ModelServiceClient.parse_model_evaluation_slice_path(path) assert expected == actual - def test_training_pipeline_path(): project = "cuttlefish" location = "mussel" training_pipeline = "winkle" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) - actual = ModelServiceClient.training_pipeline_path( - project, location, training_pipeline - ) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "nautilus", - "location": "scallop", - "training_pipeline": "abalone", + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", + } path = ModelServiceClient.training_pipeline_path(**expected) @@ -3786,20 +3998,18 @@ def test_parse_training_pipeline_path(): actual = ModelServiceClient.parse_training_pipeline_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = ModelServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = ModelServiceClient.common_billing_account_path(**expected) @@ -3807,18 +4017,18 @@ def test_parse_common_billing_account_path(): actual = ModelServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = ModelServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = ModelServiceClient.common_folder_path(**expected) @@ -3826,18 +4036,18 @@ def test_parse_common_folder_path(): actual = ModelServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = ModelServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = ModelServiceClient.common_organization_path(**expected) @@ -3845,18 +4055,18 @@ def test_parse_common_organization_path(): actual = ModelServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = ModelServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = ModelServiceClient.common_project_path(**expected) @@ -3864,22 +4074,20 @@ def test_parse_common_project_path(): actual = ModelServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = ModelServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = ModelServiceClient.common_location_path(**expected) @@ -3891,19 +4099,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: transport_class = ModelServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index d1d65aecbd..be11879c35 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceClient from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports from google.cloud.aiplatform_v1beta1.types import deployed_model_ref @@ -54,9 +50,7 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.longrunning import operations_pb2 from google.oauth2 import service_account from google.protobuf import any_pb2 as gp_any # type: ignore @@ -74,11 +68,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -89,52 +79,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PipelineServiceClient._get_default_mtls_endpoint(None) is None - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + PipelineServiceClient, + PipelineServiceAsyncClient, +]) def test_pipeline_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + PipelineServiceClient, + PipelineServiceAsyncClient, +]) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -144,7 +118,7 @@ def test_pipeline_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_pipeline_service_client_get_transport_class(): @@ -158,44 +132,29 @@ def test_pipeline_service_client_get_transport_class(): assert transport == transports.PipelineServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) -def test_pipeline_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +def test_pipeline_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -211,7 +170,7 @@ def test_pipeline_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -227,7 +186,7 @@ def test_pipeline_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -247,15 +206,13 @@ def test_pipeline_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -268,62 +225,26 @@ def test_pipeline_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "true", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "false", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_pipeline_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -346,18 +267,10 @@ def test_pipeline_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -378,14 +291,9 @@ def test_pipeline_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -399,23 +307,16 @@ def test_pipeline_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -428,24 +329,16 @@ def test_pipeline_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -460,12 +353,10 @@ def test_pipeline_service_client_client_options_credentials_file( def test_pipeline_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = PipelineServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -478,11 +369,10 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest -): +def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -491,14 +381,18 @@ def test_create_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) response = client.create_training_pipeline(request) @@ -513,11 +407,11 @@ def test_create_training_pipeline( assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -530,27 +424,25 @@ def test_create_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: client.create_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CreateTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_create_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.CreateTrainingPipelineRequest, -): +async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CreateTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -559,17 +451,15 @@ async def test_create_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) response = await client.create_training_pipeline(request) @@ -582,11 +472,11 @@ async def test_create_training_pipeline_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -597,17 +487,19 @@ async def test_create_training_pipeline_async_from_dict(): def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: call.return_value = gca_training_pipeline.TrainingPipeline() client.create_training_pipeline(request) @@ -619,25 +511,28 @@ def test_create_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) + type(client.transport.create_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) await client.create_training_pipeline(request) @@ -648,24 +543,29 @@ async def test_create_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -673,45 +573,45 @@ def test_create_training_pipeline_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -719,32 +619,31 @@ async def test_create_training_pipeline_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') @pytest.mark.asyncio async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) -def test_get_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest -): +def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -753,14 +652,18 @@ def test_get_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) response = client.get_training_pipeline(request) @@ -775,11 +678,11 @@ def test_get_training_pipeline( assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -792,27 +695,25 @@ def test_get_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: client.get_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.GetTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_get_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.GetTrainingPipelineRequest, -): +async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.GetTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -821,17 +722,15 @@ async def test_get_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) response = await client.get_training_pipeline(request) @@ -844,11 +743,11 @@ async def test_get_training_pipeline_async( # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -859,17 +758,19 @@ async def test_get_training_pipeline_async_from_dict(): def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: call.return_value = training_pipeline.TrainingPipeline() client.get_training_pipeline(request) @@ -881,25 +782,28 @@ def test_get_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) + type(client.transport.get_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) await client.get_training_pipeline(request) @@ -910,85 +814,99 @@ async def test_get_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_training_pipeline(name="name_value",) + client.get_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_training_pipeline(name="name_value",) + response = await client.get_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', ) -def test_list_training_pipelines( - transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest -): +def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -997,11 +915,12 @@ def test_list_training_pipelines( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_training_pipelines(request) @@ -1016,7 +935,7 @@ def test_list_training_pipelines( assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_training_pipelines_from_dict(): @@ -1027,27 +946,25 @@ def test_list_training_pipelines_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: client.list_training_pipelines() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.ListTrainingPipelinesRequest() - @pytest.mark.asyncio -async def test_list_training_pipelines_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.ListTrainingPipelinesRequest, -): +async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.ListTrainingPipelinesRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1056,14 +973,12 @@ async def test_list_training_pipelines_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( + next_page_token='next_page_token_value', + )) response = await client.list_training_pipelines(request) @@ -1076,7 +991,7 @@ async def test_list_training_pipelines_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1085,17 +1000,19 @@ async def test_list_training_pipelines_async_from_dict(): def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: call.return_value = pipeline_service.ListTrainingPipelinesResponse() client.list_training_pipelines(request) @@ -1107,25 +1024,28 @@ def test_list_training_pipelines_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) + type(client.transport.list_training_pipelines), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) await client.list_training_pipelines(request) @@ -1136,87 +1056,104 @@ async def test_list_training_pipelines_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_training_pipelines_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_training_pipelines(parent="parent_value",) + client.list_training_pipelines( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_training_pipelines_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + pipeline_service.ListTrainingPipelinesRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_training_pipelines_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_training_pipelines(parent="parent_value",) + response = await client.list_training_pipelines( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + pipeline_service.ListTrainingPipelinesRequest(), + parent='parent_value', ) def test_list_training_pipelines_pager(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1225,14 +1162,17 @@ def test_list_training_pipelines_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1245,7 +1185,9 @@ def test_list_training_pipelines_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_training_pipelines(request={}) @@ -1253,16 +1195,18 @@ def test_list_training_pipelines_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) - + assert all(isinstance(i, training_pipeline.TrainingPipeline) + for i in results) def test_list_training_pipelines_pages(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1271,14 +1215,17 @@ def test_list_training_pipelines_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1289,20 +1236,19 @@ def test_list_training_pipelines_pages(): RuntimeError, ) pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_training_pipelines), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1311,14 +1257,17 @@ async def test_list_training_pipelines_async_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1329,25 +1278,25 @@ async def test_list_training_pipelines_async_pager(): RuntimeError, ) async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) - + assert all(isinstance(i, training_pipeline.TrainingPipeline) + for i in responses) @pytest.mark.asyncio async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_training_pipelines), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1356,14 +1305,17 @@ async def test_list_training_pipelines_async_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token="abc", + next_page_token='abc', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", + training_pipelines=[], + next_page_token='def', ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", + training_pipelines=[ + training_pipeline.TrainingPipeline(), + ], + next_page_token='ghi', ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1376,15 +1328,14 @@ async def test_list_training_pipelines_async_pages(): pages = [] async for page_ in (await client.list_training_pipelines(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest -): +def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1393,10 +1344,10 @@ def test_delete_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_training_pipeline(request) @@ -1418,27 +1369,25 @@ def test_delete_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: client.delete_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_delete_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.DeleteTrainingPipelineRequest, -): +async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.DeleteTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1447,11 +1396,11 @@ async def test_delete_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_training_pipeline(request) @@ -1472,18 +1421,20 @@ async def test_delete_training_pipeline_async_from_dict(): def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_training_pipeline), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_training_pipeline(request) @@ -1494,25 +1445,28 @@ def test_delete_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_training_pipeline(request) @@ -1523,85 +1477,101 @@ async def test_delete_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_training_pipeline(name="name_value",) + client.delete_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + pipeline_service.DeleteTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), "__call__" - ) as call: + type(client.transport.delete_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_training_pipeline(name="name_value",) + response = await client.delete_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + pipeline_service.DeleteTrainingPipelineRequest(), + name='name_value', ) -def test_cancel_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest -): +def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1610,8 +1580,8 @@ def test_cancel_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1635,27 +1605,25 @@ def test_cancel_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: client.cancel_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CancelTrainingPipelineRequest() - @pytest.mark.asyncio -async def test_cancel_training_pipeline_async( - transport: str = "grpc_asyncio", - request_type=pipeline_service.CancelTrainingPipelineRequest, -): +async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CancelTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1664,8 +1632,8 @@ async def test_cancel_training_pipeline_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1687,17 +1655,19 @@ async def test_cancel_training_pipeline_async_from_dict(): def test_cancel_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: call.return_value = None client.cancel_training_pipeline(request) @@ -1709,22 +1679,27 @@ def test_cancel_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_training_pipeline(request) @@ -1736,75 +1711,92 @@ async def test_cancel_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_training_pipeline(name="name_value",) + client.cancel_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + pipeline_service.CancelTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), "__call__" - ) as call: + type(client.transport.cancel_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_training_pipeline(name="name_value",) + response = await client.cancel_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + pipeline_service.CancelTrainingPipelineRequest(), + name='name_value', ) @@ -1815,7 +1807,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1834,7 +1827,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1862,16 +1856,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1879,8 +1870,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.PipelineServiceGrpcTransport, + ) def test_pipeline_service_base_transport_error(): @@ -1888,15 +1884,13 @@ def test_pipeline_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_pipeline_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1905,12 +1899,12 @@ def test_pipeline_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_training_pipeline", - "get_training_pipeline", - "list_training_pipelines", - "delete_training_pipeline", - "cancel_training_pipeline", - ) + 'create_training_pipeline', + 'get_training_pipeline', + 'list_training_pipelines', + 'delete_training_pipeline', + 'cancel_training_pipeline', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1923,28 +1917,23 @@ def test_pipeline_service_base_transport(): def test_pipeline_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_pipeline_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport() @@ -1953,11 +1942,11 @@ def test_pipeline_service_base_transport_with_adc(): def test_pipeline_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PipelineServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -1965,25 +1954,19 @@ def test_pipeline_service_auth_adc(): def test_pipeline_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) -def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1992,13 +1975,15 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_c transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2013,40 +1998,38 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_c with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_pipeline_service_host_no_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_pipeline_service_host_with_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_pipeline_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2054,11 +2037,12 @@ def test_pipeline_service_grpc_transport_channel(): def test_pipeline_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2067,22 +2051,12 @@ def test_pipeline_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) def test_pipeline_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2091,7 +2065,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2107,7 +2081,9 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2121,23 +2097,17 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, - ], -) -def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +def test_pipeline_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2154,7 +2124,9 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2167,12 +2139,16 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): def test_pipeline_service_grpc_lro_client(): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2180,12 +2156,16 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2196,18 +2176,17 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = PipelineServiceClient.endpoint_path(**expected) @@ -2215,24 +2194,22 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = PipelineServiceClient.model_path(**expected) @@ -2240,26 +2217,22 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual - def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) - actual = PipelineServiceClient.training_pipeline_path( - project, location, training_pipeline - ) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", + } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2267,20 +2240,18 @@ def test_parse_training_pipeline_path(): actual = PipelineServiceClient.parse_training_pipeline_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = PipelineServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "mussel", + } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2288,18 +2259,18 @@ def test_parse_common_billing_account_path(): actual = PipelineServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = PipelineServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "nautilus", + } path = PipelineServiceClient.common_folder_path(**expected) @@ -2307,18 +2278,18 @@ def test_parse_common_folder_path(): actual = PipelineServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = PipelineServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "abalone", + } path = PipelineServiceClient.common_organization_path(**expected) @@ -2326,18 +2297,18 @@ def test_parse_common_organization_path(): actual = PipelineServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = PipelineServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "clam", + } path = PipelineServiceClient.common_project_path(**expected) @@ -2345,22 +2316,20 @@ def test_parse_common_project_path(): actual = PipelineServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = PipelineServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "oyster", + "location": "nudibranch", + } path = PipelineServiceClient.common_location_path(**expected) @@ -2372,19 +2341,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.PipelineServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.PipelineServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: transport_class = PipelineServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index 879a0a69d5..06ec395aaf 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( - SpecialistPoolServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( - SpecialistPoolServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceClient from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -60,11 +56,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -75,53 +67,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + SpecialistPoolServiceClient, + SpecialistPoolServiceAsyncClient, +]) def test_specialist_pool_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + SpecialistPoolServiceClient, + SpecialistPoolServiceAsyncClient, +]) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -131,7 +106,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_specialist_pool_service_client_get_transport_class(): @@ -145,48 +120,29 @@ def test_specialist_pool_service_client_get_transport_class(): assert transport == transports.SpecialistPoolServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - SpecialistPoolServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceClient), -) -@mock.patch.object( - SpecialistPoolServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceAsyncClient), -) -def test_specialist_pool_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) +@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) +def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -202,7 +158,7 @@ def test_specialist_pool_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -218,7 +174,7 @@ def test_specialist_pool_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -238,15 +194,13 @@ def test_specialist_pool_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -259,62 +213,26 @@ def test_specialist_pool_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - "true", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - "false", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - SpecialistPoolServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceClient), -) -@mock.patch.object( - SpecialistPoolServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceAsyncClient), -) + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) +@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -337,18 +255,10 @@ def test_specialist_pool_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -369,14 +279,9 @@ def test_specialist_pool_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -390,27 +295,16 @@ def test_specialist_pool_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -423,28 +317,16 @@ def test_specialist_pool_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -459,12 +341,10 @@ def test_specialist_pool_service_client_client_options_credentials_file( def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -477,12 +357,10 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -491,10 +369,10 @@ def test_create_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_specialist_pool(request) @@ -516,27 +394,25 @@ def test_create_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: client.create_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_create_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.CreateSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -545,11 +421,11 @@ async def test_create_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_specialist_pool(request) @@ -577,13 +453,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_specialist_pool(request) @@ -594,7 +470,10 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -606,15 +485,13 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_specialist_pool(request) @@ -625,7 +502,10 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_specialist_pool_flattened(): @@ -635,16 +515,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -652,11 +532,9 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') def test_create_specialist_pool_flattened_error(): @@ -669,8 +547,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) @@ -682,19 +560,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -702,11 +580,9 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') @pytest.mark.asyncio @@ -720,17 +596,15 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) -def test_get_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -739,15 +613,20 @@ def test_get_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + + specialist_manager_emails=['specialist_manager_emails_value'], + + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + ) response = client.get_specialist_pool(request) @@ -762,15 +641,15 @@ def test_get_specialist_pool( assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] def test_get_specialist_pool_from_dict(): @@ -781,27 +660,25 @@ def test_get_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: client.get_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_get_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.GetSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -810,18 +687,16 @@ async def test_get_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( + name='name_value', + display_name='display_name_value', + specialist_managers_count=2662, + specialist_manager_emails=['specialist_manager_emails_value'], + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + )) response = await client.get_specialist_pool(request) @@ -834,15 +709,15 @@ async def test_get_specialist_pool_async( # Establish that the response is the type that we expect. assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] @pytest.mark.asyncio @@ -858,12 +733,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -875,7 +750,10 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -887,15 +765,13 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) + type(client.transport.get_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) await client.get_specialist_pool(request) @@ -906,7 +782,10 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_specialist_pool_flattened(): @@ -916,21 +795,23 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_specialist_pool(name="name_value",) + client.get_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_specialist_pool_flattened_error(): @@ -942,7 +823,8 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + specialist_pool_service.GetSpecialistPoolRequest(), + name='name_value', ) @@ -954,24 +836,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool(name="name_value",) + response = await client.get_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio @@ -984,16 +866,15 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + specialist_pool_service.GetSpecialistPoolRequest(), + name='name_value', ) -def test_list_specialist_pools( - transport: str = "grpc", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1002,11 +883,12 @@ def test_list_specialist_pools( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_specialist_pools(request) @@ -1021,7 +903,7 @@ def test_list_specialist_pools( assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_specialist_pools_from_dict(): @@ -1032,27 +914,25 @@ def test_list_specialist_pools_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: client.list_specialist_pools() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() - @pytest.mark.asyncio -async def test_list_specialist_pools_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.ListSpecialistPoolsRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1061,14 +941,12 @@ async def test_list_specialist_pools_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_specialist_pools(request) @@ -1081,7 +959,7 @@ async def test_list_specialist_pools_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1097,12 +975,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -1114,7 +992,10 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1126,15 +1007,13 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) + type(client.transport.list_specialist_pools), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) await client.list_specialist_pools(request) @@ -1145,7 +1024,10 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_specialist_pools_flattened(): @@ -1155,21 +1037,23 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_specialist_pools(parent="parent_value",) + client.list_specialist_pools( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_specialist_pools_flattened_error(): @@ -1181,7 +1065,8 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + specialist_pool_service.ListSpecialistPoolsRequest(), + parent='parent_value', ) @@ -1193,24 +1078,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools(parent="parent_value",) + response = await client.list_specialist_pools( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio @@ -1223,17 +1108,20 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + specialist_pool_service.ListSpecialistPoolsRequest(), + parent='parent_value', ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1242,14 +1130,17 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1262,7 +1153,9 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_specialist_pools(request={}) @@ -1270,16 +1163,18 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) - + assert all(isinstance(i, specialist_pool.SpecialistPool) + for i in results) def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1288,14 +1183,17 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1306,10 +1204,9 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1318,10 +1215,8 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_specialist_pools), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1330,14 +1225,17 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1348,14 +1246,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) - + assert all(isinstance(i, specialist_pool.SpecialistPool) + for i in responses) @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1365,10 +1263,8 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_specialist_pools), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1377,14 +1273,17 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1397,16 +1296,14 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1415,10 +1312,10 @@ def test_delete_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_specialist_pool(request) @@ -1440,27 +1337,25 @@ def test_delete_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: client.delete_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_delete_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1469,11 +1364,11 @@ async def test_delete_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_specialist_pool(request) @@ -1501,13 +1396,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_specialist_pool(request) @@ -1518,7 +1413,10 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1530,15 +1428,13 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_specialist_pool(request) @@ -1549,7 +1445,10 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_specialist_pool_flattened(): @@ -1559,21 +1458,23 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool(name="name_value",) + client.delete_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_specialist_pool_flattened_error(): @@ -1585,7 +1486,8 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + specialist_pool_service.DeleteSpecialistPoolRequest(), + name='name_value', ) @@ -1597,24 +1499,26 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool(name="name_value",) + response = await client.delete_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio @@ -1627,16 +1531,15 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + specialist_pool_service.DeleteSpecialistPoolRequest(), + name='name_value', ) -def test_update_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1645,10 +1548,10 @@ def test_update_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.update_specialist_pool(request) @@ -1670,27 +1573,25 @@ def test_update_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: client.update_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() - @pytest.mark.asyncio -async def test_update_specialist_pool_async( - transport: str = "grpc_asyncio", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1699,11 +1600,11 @@ async def test_update_specialist_pool_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.update_specialist_pool(request) @@ -1731,13 +1632,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" + request.specialist_pool.name = 'specialist_pool.name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.update_specialist_pool(request) @@ -1749,9 +1650,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1763,15 +1664,13 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" + request.specialist_pool.name = 'specialist_pool.name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.update_specialist_pool(request) @@ -1783,9 +1682,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] def test_update_specialist_pool_flattened(): @@ -1795,16 +1694,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1812,11 +1711,9 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_specialist_pool_flattened_error(): @@ -1829,8 +1726,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @@ -1842,19 +1739,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1862,11 +1759,9 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio @@ -1880,8 +1775,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @@ -1892,7 +1787,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1911,7 +1807,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1939,16 +1836,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1959,7 +1853,10 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) + assert isinstance( + client.transport, + transports.SpecialistPoolServiceGrpcTransport, + ) def test_specialist_pool_service_base_transport_error(): @@ -1967,15 +1864,13 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1984,12 +1879,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_specialist_pool", - "get_specialist_pool", - "list_specialist_pools", - "delete_specialist_pool", - "update_specialist_pool", - ) + 'create_specialist_pool', + 'get_specialist_pool', + 'list_specialist_pools', + 'delete_specialist_pool', + 'update_specialist_pool', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2002,28 +1897,23 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -2032,11 +1922,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -2044,26 +1934,18 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( - transport_class, + transport_class ): cred = credentials.AnonymousCredentials() @@ -2073,13 +1955,15 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2094,40 +1978,38 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2135,11 +2017,12 @@ def test_specialist_pool_service_grpc_transport_channel(): def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2148,22 +2031,12 @@ def test_specialist_pool_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2172,7 +2045,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2188,7 +2061,9 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2202,23 +2077,17 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) -def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +def test_specialist_pool_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2235,7 +2104,9 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2248,12 +2119,16 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2261,12 +2136,16 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2277,20 +2156,17 @@ def test_specialist_pool_path(): location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( - project=project, location=location, specialist_pool=specialist_pool, - ) - actual = SpecialistPoolServiceClient.specialist_pool_path( - project, location, specialist_pool - ) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", + } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -2298,20 +2174,18 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "mussel", + } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -2319,18 +2193,18 @@ def test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "nautilus", + } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2338,18 +2212,18 @@ def test_parse_common_folder_path(): actual = SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "abalone", + } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2357,18 +2231,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "clam", + } path = SpecialistPoolServiceClient.common_project_path(**expected) @@ -2376,22 +2250,20 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "oyster", + "location": "nudibranch", + } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2403,19 +2275,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py index 5f1aec70ab..3370e5011e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.vizier_service import ( - VizierServiceAsyncClient, -) +from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceAsyncClient from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceClient from google.cloud.aiplatform_v1beta1.services.vizier_service import pagers from google.cloud.aiplatform_v1beta1.services.vizier_service import transports @@ -46,6 +44,7 @@ from google.cloud.aiplatform_v1beta1.types import vizier_service from google.longrunning import operations_pb2 from google.oauth2 import service_account +from google.protobuf import duration_pb2 as duration # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -58,11 +57,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -73,52 +68,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert VizierServiceClient._get_default_mtls_endpoint(None) is None - assert ( - VizierServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - VizierServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - VizierServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - VizierServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - VizierServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert VizierServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert VizierServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert VizierServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert VizierServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert VizierServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [VizierServiceClient, VizierServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + VizierServiceClient, + VizierServiceAsyncClient, +]) def test_vizier_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' -@pytest.mark.parametrize( - "client_class", [VizierServiceClient, VizierServiceAsyncClient,], -) +@pytest.mark.parametrize("client_class", [ + VizierServiceClient, + VizierServiceAsyncClient, +]) def test_vizier_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -128,7 +107,7 @@ def test_vizier_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_vizier_service_client_get_transport_class(): @@ -142,44 +121,29 @@ def test_vizier_service_client_get_transport_class(): assert transport == transports.VizierServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), - ( - VizierServiceAsyncClient, - transports.VizierServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - VizierServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(VizierServiceClient), -) -@mock.patch.object( - VizierServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(VizierServiceAsyncClient), -) -def test_vizier_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(VizierServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceClient)) +@mock.patch.object(VizierServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceAsyncClient)) +def test_vizier_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(VizierServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(VizierServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(VizierServiceClient, "get_transport_class") as gtc: + with mock.patch.object(VizierServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -195,7 +159,7 @@ def test_vizier_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -211,7 +175,7 @@ def test_vizier_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -231,15 +195,13 @@ def test_vizier_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -252,52 +214,26 @@ def test_vizier_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "true"), - ( - VizierServiceAsyncClient, - transports.VizierServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "false"), - ( - VizierServiceAsyncClient, - transports.VizierServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - VizierServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(VizierServiceClient), -) -@mock.patch.object( - VizierServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(VizierServiceAsyncClient), -) + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "true"), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "false"), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(VizierServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceClient)) +@mock.patch.object(VizierServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_vizier_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_vizier_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) @@ -320,18 +256,10 @@ def test_vizier_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, - ): - with mock.patch( - "google.auth.transport.mtls.default_client_cert_source", - return_value=client_cert_source_callback, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -352,14 +280,9 @@ def test_vizier_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -373,23 +296,16 @@ def test_vizier_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), - ( - VizierServiceAsyncClient, - transports.VizierServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_vizier_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_vizier_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -402,24 +318,16 @@ def test_vizier_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), - ( - VizierServiceAsyncClient, - transports.VizierServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_vizier_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_vizier_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -434,12 +342,10 @@ def test_vizier_service_client_client_options_credentials_file( def test_vizier_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = VizierServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -452,11 +358,10 @@ def test_vizier_service_client_client_options_from_dict(): ) -def test_create_study( - transport: str = "grpc", request_type=vizier_service.CreateStudyRequest -): +def test_create_study(transport: str = 'grpc', request_type=vizier_service.CreateStudyRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -464,13 +369,19 @@ def test_create_study( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_study.Study( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=gca_study.Study.State.ACTIVE, - inactive_reason="inactive_reason_value", + + inactive_reason='inactive_reason_value', + ) response = client.create_study(request) @@ -485,13 +396,13 @@ def test_create_study( assert isinstance(response, gca_study.Study) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == gca_study.Study.State.ACTIVE - assert response.inactive_reason == "inactive_reason_value" + assert response.inactive_reason == 'inactive_reason_value' def test_create_study_from_dict(): @@ -502,24 +413,25 @@ def test_create_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: client.create_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CreateStudyRequest() - @pytest.mark.asyncio -async def test_create_study_async( - transport: str = "grpc_asyncio", request_type=vizier_service.CreateStudyRequest -): +async def test_create_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CreateStudyRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -527,16 +439,16 @@ async def test_create_study_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_study.Study( - name="name_value", - display_name="display_name_value", - state=gca_study.Study.State.ACTIVE, - inactive_reason="inactive_reason_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_study.Study( + name='name_value', + display_name='display_name_value', + state=gca_study.Study.State.ACTIVE, + inactive_reason='inactive_reason_value', + )) response = await client.create_study(request) @@ -549,13 +461,13 @@ async def test_create_study_async( # Establish that the response is the type that we expect. assert isinstance(response, gca_study.Study) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == gca_study.Study.State.ACTIVE - assert response.inactive_reason == "inactive_reason_value" + assert response.inactive_reason == 'inactive_reason_value' @pytest.mark.asyncio @@ -564,15 +476,19 @@ async def test_create_study_async_from_dict(): def test_create_study_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateStudyRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: call.return_value = gca_study.Study() client.create_study(request) @@ -584,20 +500,27 @@ def test_create_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_study_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateStudyRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_study.Study()) await client.create_study(request) @@ -609,21 +532,29 @@ async def test_create_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_study_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_study.Study() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_study( - parent="parent_value", study=gca_study.Study(name="name_value"), + parent='parent_value', + study=gca_study.Study(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -631,30 +562,36 @@ def test_create_study_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].study == gca_study.Study(name="name_value") + assert args[0].study == gca_study.Study(name='name_value') def test_create_study_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_study( vizier_service.CreateStudyRequest(), - parent="parent_value", - study=gca_study.Study(name="name_value"), + parent='parent_value', + study=gca_study.Study(name='name_value'), ) @pytest.mark.asyncio async def test_create_study_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_study), "__call__") as call: + with mock.patch.object( + type(client.transport.create_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_study.Study() @@ -662,7 +599,8 @@ async def test_create_study_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_study( - parent="parent_value", study=gca_study.Study(name="name_value"), + parent='parent_value', + study=gca_study.Study(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -670,30 +608,31 @@ async def test_create_study_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].study == gca_study.Study(name="name_value") + assert args[0].study == gca_study.Study(name='name_value') @pytest.mark.asyncio async def test_create_study_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_study( vizier_service.CreateStudyRequest(), - parent="parent_value", - study=gca_study.Study(name="name_value"), + parent='parent_value', + study=gca_study.Study(name='name_value'), ) -def test_get_study( - transport: str = "grpc", request_type=vizier_service.GetStudyRequest -): +def test_get_study(transport: str = 'grpc', request_type=vizier_service.GetStudyRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -701,13 +640,19 @@ def test_get_study( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Study( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=study.Study.State.ACTIVE, - inactive_reason="inactive_reason_value", + + inactive_reason='inactive_reason_value', + ) response = client.get_study(request) @@ -722,13 +667,13 @@ def test_get_study( assert isinstance(response, study.Study) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == "inactive_reason_value" + assert response.inactive_reason == 'inactive_reason_value' def test_get_study_from_dict(): @@ -739,24 +684,25 @@ def test_get_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: client.get_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.GetStudyRequest() - @pytest.mark.asyncio -async def test_get_study_async( - transport: str = "grpc_asyncio", request_type=vizier_service.GetStudyRequest -): +async def test_get_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.GetStudyRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -764,16 +710,16 @@ async def test_get_study_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Study( - name="name_value", - display_name="display_name_value", - state=study.Study.State.ACTIVE, - inactive_reason="inactive_reason_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study( + name='name_value', + display_name='display_name_value', + state=study.Study.State.ACTIVE, + inactive_reason='inactive_reason_value', + )) response = await client.get_study(request) @@ -786,13 +732,13 @@ async def test_get_study_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Study) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == "inactive_reason_value" + assert response.inactive_reason == 'inactive_reason_value' @pytest.mark.asyncio @@ -801,15 +747,19 @@ async def test_get_study_async_from_dict(): def test_get_study_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetStudyRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: call.return_value = study.Study() client.get_study(request) @@ -821,20 +771,27 @@ def test_get_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_study_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetStudyRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) await client.get_study(request) @@ -846,79 +803,99 @@ async def test_get_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_study_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Study() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_study(name="name_value",) + client.get_study( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_study_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_study( - vizier_service.GetStudyRequest(), name="name_value", + vizier_service.GetStudyRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_study_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_study), "__call__") as call: + with mock.patch.object( + type(client.transport.get_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Study() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_study(name="name_value",) + response = await client.get_study( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_study_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_study( - vizier_service.GetStudyRequest(), name="name_value", + vizier_service.GetStudyRequest(), + name='name_value', ) -def test_list_studies( - transport: str = "grpc", request_type=vizier_service.ListStudiesRequest -): +def test_list_studies(transport: str = 'grpc', request_type=vizier_service.ListStudiesRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -926,10 +903,13 @@ def test_list_studies( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListStudiesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_studies(request) @@ -944,7 +924,7 @@ def test_list_studies( assert isinstance(response, pagers.ListStudiesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_studies_from_dict(): @@ -955,24 +935,25 @@ def test_list_studies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: client.list_studies() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.ListStudiesRequest() - @pytest.mark.asyncio -async def test_list_studies_async( - transport: str = "grpc_asyncio", request_type=vizier_service.ListStudiesRequest -): +async def test_list_studies_async(transport: str = 'grpc_asyncio', request_type=vizier_service.ListStudiesRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -980,11 +961,13 @@ async def test_list_studies_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListStudiesResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListStudiesResponse( + next_page_token='next_page_token_value', + )) response = await client.list_studies(request) @@ -997,7 +980,7 @@ async def test_list_studies_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListStudiesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -1006,15 +989,19 @@ async def test_list_studies_async_from_dict(): def test_list_studies_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListStudiesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: call.return_value = vizier_service.ListStudiesResponse() client.list_studies(request) @@ -1026,23 +1013,28 @@ def test_list_studies_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_studies_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListStudiesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListStudiesResponse() - ) + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListStudiesResponse()) await client.list_studies(request) @@ -1053,100 +1045,138 @@ async def test_list_studies_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_studies_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListStudiesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_studies(parent="parent_value",) + client.list_studies( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_studies_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_studies( - vizier_service.ListStudiesRequest(), parent="parent_value", + vizier_service.ListStudiesRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_studies_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListStudiesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListStudiesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListStudiesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_studies(parent="parent_value",) + response = await client.list_studies( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_studies_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_studies( - vizier_service.ListStudiesRequest(), parent="parent_value", + vizier_service.ListStudiesRequest(), + parent='parent_value', ) def test_list_studies_pager(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(), study.Study(),], - next_page_token="abc", + studies=[ + study.Study(), + study.Study(), + study.Study(), + ], + next_page_token='abc', ), - vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[study.Study(),], next_page_token="ghi", + studies=[], + next_page_token='def', ), vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(),], + studies=[ + study.Study(), + ], + next_page_token='ghi', + ), + vizier_service.ListStudiesResponse( + studies=[ + study.Study(), + study.Study(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_studies(request={}) @@ -1154,102 +1184,147 @@ def test_list_studies_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, study.Study) for i in results) - + assert all(isinstance(i, study.Study) + for i in results) def test_list_studies_pages(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + with mock.patch.object( + type(client.transport.list_studies), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(), study.Study(),], - next_page_token="abc", + studies=[ + study.Study(), + study.Study(), + study.Study(), + ], + next_page_token='abc', ), - vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[study.Study(),], next_page_token="ghi", + studies=[], + next_page_token='def', ), vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(),], + studies=[ + study.Study(), + ], + next_page_token='ghi', + ), + vizier_service.ListStudiesResponse( + studies=[ + study.Study(), + study.Study(), + ], ), RuntimeError, ) pages = list(client.list_studies(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_studies_async_pager(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_studies), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_studies), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(), study.Study(),], - next_page_token="abc", + studies=[ + study.Study(), + study.Study(), + study.Study(), + ], + next_page_token='abc', ), - vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[study.Study(),], next_page_token="ghi", + studies=[], + next_page_token='def', ), vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(),], + studies=[ + study.Study(), + ], + next_page_token='ghi', + ), + vizier_service.ListStudiesResponse( + studies=[ + study.Study(), + study.Study(), + ], ), RuntimeError, ) async_pager = await client.list_studies(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, study.Study) for i in responses) - + assert all(isinstance(i, study.Study) + for i in responses) @pytest.mark.asyncio async def test_list_studies_async_pages(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_studies), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_studies), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(), study.Study(),], - next_page_token="abc", + studies=[ + study.Study(), + study.Study(), + study.Study(), + ], + next_page_token='abc', + ), + vizier_service.ListStudiesResponse( + studies=[], + next_page_token='def', ), - vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[study.Study(),], next_page_token="ghi", + studies=[ + study.Study(), + ], + next_page_token='ghi', ), vizier_service.ListStudiesResponse( - studies=[study.Study(), study.Study(),], + studies=[ + study.Study(), + study.Study(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_studies(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_study( - transport: str = "grpc", request_type=vizier_service.DeleteStudyRequest -): +def test_delete_study(transport: str = 'grpc', request_type=vizier_service.DeleteStudyRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1257,7 +1332,9 @@ def test_delete_study( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1281,24 +1358,25 @@ def test_delete_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: client.delete_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.DeleteStudyRequest() - @pytest.mark.asyncio -async def test_delete_study_async( - transport: str = "grpc_asyncio", request_type=vizier_service.DeleteStudyRequest -): +async def test_delete_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.DeleteStudyRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1306,7 +1384,9 @@ async def test_delete_study_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1328,15 +1408,19 @@ async def test_delete_study_async_from_dict(): def test_delete_study_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteStudyRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: call.return_value = None client.delete_study(request) @@ -1348,20 +1432,27 @@ def test_delete_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_study_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteStudyRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.delete_study(request) @@ -1373,79 +1464,99 @@ async def test_delete_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_study_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_study(name="name_value",) + client.delete_study( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_study_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_study( - vizier_service.DeleteStudyRequest(), name="name_value", + vizier_service.DeleteStudyRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_study_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_study(name="name_value",) + response = await client.delete_study( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_study_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_study( - vizier_service.DeleteStudyRequest(), name="name_value", + vizier_service.DeleteStudyRequest(), + name='name_value', ) -def test_lookup_study( - transport: str = "grpc", request_type=vizier_service.LookupStudyRequest -): +def test_lookup_study(transport: str = 'grpc', request_type=vizier_service.LookupStudyRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1453,13 +1564,19 @@ def test_lookup_study( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Study( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=study.Study.State.ACTIVE, - inactive_reason="inactive_reason_value", + + inactive_reason='inactive_reason_value', + ) response = client.lookup_study(request) @@ -1474,13 +1591,13 @@ def test_lookup_study( assert isinstance(response, study.Study) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == "inactive_reason_value" + assert response.inactive_reason == 'inactive_reason_value' def test_lookup_study_from_dict(): @@ -1491,24 +1608,25 @@ def test_lookup_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: client.lookup_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.LookupStudyRequest() - @pytest.mark.asyncio -async def test_lookup_study_async( - transport: str = "grpc_asyncio", request_type=vizier_service.LookupStudyRequest -): +async def test_lookup_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.LookupStudyRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1516,16 +1634,16 @@ async def test_lookup_study_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Study( - name="name_value", - display_name="display_name_value", - state=study.Study.State.ACTIVE, - inactive_reason="inactive_reason_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study( + name='name_value', + display_name='display_name_value', + state=study.Study.State.ACTIVE, + inactive_reason='inactive_reason_value', + )) response = await client.lookup_study(request) @@ -1538,13 +1656,13 @@ async def test_lookup_study_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Study) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == "inactive_reason_value" + assert response.inactive_reason == 'inactive_reason_value' @pytest.mark.asyncio @@ -1553,15 +1671,19 @@ async def test_lookup_study_async_from_dict(): def test_lookup_study_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.LookupStudyRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: call.return_value = study.Study() client.lookup_study(request) @@ -1573,20 +1695,27 @@ def test_lookup_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_lookup_study_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.LookupStudyRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) await client.lookup_study(request) @@ -1598,79 +1727,99 @@ async def test_lookup_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_lookup_study_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Study() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.lookup_study(parent="parent_value",) + client.lookup_study( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_lookup_study_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.lookup_study( - vizier_service.LookupStudyRequest(), parent="parent_value", + vizier_service.LookupStudyRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_lookup_study_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + with mock.patch.object( + type(client.transport.lookup_study), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Study() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.lookup_study(parent="parent_value",) + response = await client.lookup_study( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_lookup_study_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.lookup_study( - vizier_service.LookupStudyRequest(), parent="parent_value", + vizier_service.LookupStudyRequest(), + parent='parent_value', ) -def test_suggest_trials( - transport: str = "grpc", request_type=vizier_service.SuggestTrialsRequest -): +def test_suggest_trials(transport: str = 'grpc', request_type=vizier_service.SuggestTrialsRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1678,9 +1827,11 @@ def test_suggest_trials( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.suggest_trials), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.suggest_trials(request) @@ -1702,24 +1853,25 @@ def test_suggest_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.suggest_trials), + '__call__') as call: client.suggest_trials() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.SuggestTrialsRequest() - @pytest.mark.asyncio -async def test_suggest_trials_async( - transport: str = "grpc_asyncio", request_type=vizier_service.SuggestTrialsRequest -): +async def test_suggest_trials_async(transport: str = 'grpc_asyncio', request_type=vizier_service.SuggestTrialsRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1727,10 +1879,12 @@ async def test_suggest_trials_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.suggest_trials), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.suggest_trials(request) @@ -1751,16 +1905,20 @@ async def test_suggest_trials_async_from_dict(): def test_suggest_trials_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.SuggestTrialsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.suggest_trials), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.suggest_trials(request) @@ -1771,23 +1929,28 @@ def test_suggest_trials_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_suggest_trials_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.SuggestTrialsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.suggest_trials), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.suggest_trials(request) @@ -1798,14 +1961,16 @@ async def test_suggest_trials_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] -def test_create_trial( - transport: str = "grpc", request_type=vizier_service.CreateTrialRequest -): +def test_create_trial(transport: str = 'grpc', request_type=vizier_service.CreateTrialRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1813,13 +1978,23 @@ def test_create_trial( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name="name_value", - id="id_value", + name='name_value', + + id='id_value', + state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", + + client_id='client_id_value', + + infeasible_reason='infeasible_reason_value', + + custom_job='custom_job_value', + ) response = client.create_trial(request) @@ -1834,13 +2009,17 @@ def test_create_trial( assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' def test_create_trial_from_dict(): @@ -1851,24 +2030,25 @@ def test_create_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: client.create_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CreateTrialRequest() - @pytest.mark.asyncio -async def test_create_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest -): +async def test_create_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CreateTrialRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1876,16 +2056,18 @@ async def test_create_trial_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Trial( - name="name_value", - id="id_value", - state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( + name='name_value', + id='id_value', + state=study.Trial.State.REQUESTED, + client_id='client_id_value', + infeasible_reason='infeasible_reason_value', + custom_job='custom_job_value', + )) response = await client.create_trial(request) @@ -1898,13 +2080,17 @@ async def test_create_trial_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' @pytest.mark.asyncio @@ -1913,15 +2099,19 @@ async def test_create_trial_async_from_dict(): def test_create_trial_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateTrialRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: call.return_value = study.Trial() client.create_trial(request) @@ -1933,20 +2123,27 @@ def test_create_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_trial_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateTrialRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.create_trial(request) @@ -1958,21 +2155,29 @@ async def test_create_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_trial_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_trial( - parent="parent_value", trial=study.Trial(name="name_value"), + parent='parent_value', + trial=study.Trial(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1980,30 +2185,36 @@ def test_create_trial_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].trial == study.Trial(name="name_value") + assert args[0].trial == study.Trial(name='name_value') def test_create_trial_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_trial( vizier_service.CreateTrialRequest(), - parent="parent_value", - trial=study.Trial(name="name_value"), + parent='parent_value', + trial=study.Trial(name='name_value'), ) @pytest.mark.asyncio async def test_create_trial_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.create_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() @@ -2011,7 +2222,8 @@ async def test_create_trial_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_trial( - parent="parent_value", trial=study.Trial(name="name_value"), + parent='parent_value', + trial=study.Trial(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -2019,30 +2231,31 @@ async def test_create_trial_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].trial == study.Trial(name="name_value") + assert args[0].trial == study.Trial(name='name_value') @pytest.mark.asyncio async def test_create_trial_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_trial( vizier_service.CreateTrialRequest(), - parent="parent_value", - trial=study.Trial(name="name_value"), + parent='parent_value', + trial=study.Trial(name='name_value'), ) -def test_get_trial( - transport: str = "grpc", request_type=vizier_service.GetTrialRequest -): +def test_get_trial(transport: str = 'grpc', request_type=vizier_service.GetTrialRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2050,13 +2263,23 @@ def test_get_trial( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name="name_value", - id="id_value", + name='name_value', + + id='id_value', + state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", + + client_id='client_id_value', + + infeasible_reason='infeasible_reason_value', + + custom_job='custom_job_value', + ) response = client.get_trial(request) @@ -2071,13 +2294,17 @@ def test_get_trial( assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' def test_get_trial_from_dict(): @@ -2088,24 +2315,25 @@ def test_get_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: client.get_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.GetTrialRequest() - @pytest.mark.asyncio -async def test_get_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.GetTrialRequest -): +async def test_get_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.GetTrialRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2113,16 +2341,18 @@ async def test_get_trial_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Trial( - name="name_value", - id="id_value", - state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( + name='name_value', + id='id_value', + state=study.Trial.State.REQUESTED, + client_id='client_id_value', + infeasible_reason='infeasible_reason_value', + custom_job='custom_job_value', + )) response = await client.get_trial(request) @@ -2135,13 +2365,17 @@ async def test_get_trial_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' @pytest.mark.asyncio @@ -2150,15 +2384,19 @@ async def test_get_trial_async_from_dict(): def test_get_trial_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: call.return_value = study.Trial() client.get_trial(request) @@ -2170,20 +2408,27 @@ def test_get_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_trial_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.get_trial(request) @@ -2195,79 +2440,99 @@ async def test_get_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_trial_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_trial(name="name_value",) + client.get_trial( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_trial_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_trial( - vizier_service.GetTrialRequest(), name="name_value", + vizier_service.GetTrialRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_trial_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.get_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_trial(name="name_value",) + response = await client.get_trial( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_trial_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_trial( - vizier_service.GetTrialRequest(), name="name_value", + vizier_service.GetTrialRequest(), + name='name_value', ) -def test_list_trials( - transport: str = "grpc", request_type=vizier_service.ListTrialsRequest -): +def test_list_trials(transport: str = 'grpc', request_type=vizier_service.ListTrialsRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2275,10 +2540,13 @@ def test_list_trials( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListTrialsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_trials(request) @@ -2293,7 +2561,7 @@ def test_list_trials( assert isinstance(response, pagers.ListTrialsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_trials_from_dict(): @@ -2304,24 +2572,25 @@ def test_list_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: client.list_trials() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.ListTrialsRequest() - @pytest.mark.asyncio -async def test_list_trials_async( - transport: str = "grpc_asyncio", request_type=vizier_service.ListTrialsRequest -): +async def test_list_trials_async(transport: str = 'grpc_asyncio', request_type=vizier_service.ListTrialsRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2329,11 +2598,13 @@ async def test_list_trials_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListTrialsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListTrialsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_trials(request) @@ -2346,7 +2617,7 @@ async def test_list_trials_async( # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrialsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' @pytest.mark.asyncio @@ -2355,15 +2626,19 @@ async def test_list_trials_async_from_dict(): def test_list_trials_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListTrialsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: call.return_value = vizier_service.ListTrialsResponse() client.list_trials(request) @@ -2375,23 +2650,28 @@ def test_list_trials_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_trials_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListTrialsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListTrialsResponse() - ) + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListTrialsResponse()) await client.list_trials(request) @@ -2402,98 +2682,138 @@ async def test_list_trials_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_trials_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListTrialsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_trials(parent="parent_value",) + client.list_trials( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_trials_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_trials( - vizier_service.ListTrialsRequest(), parent="parent_value", + vizier_service.ListTrialsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_trials_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListTrialsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListTrialsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListTrialsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_trials(parent="parent_value",) + response = await client.list_trials( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_trials_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_trials( - vizier_service.ListTrialsRequest(), parent="parent_value", + vizier_service.ListTrialsRequest(), + parent='parent_value', ) def test_list_trials_pager(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[study.Trial(), study.Trial(), study.Trial(),], - next_page_token="abc", + trials=[ + study.Trial(), + study.Trial(), + study.Trial(), + ], + next_page_token='abc', + ), + vizier_service.ListTrialsResponse( + trials=[], + next_page_token='def', ), - vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[study.Trial(),], next_page_token="ghi", + trials=[ + study.Trial(), + ], + next_page_token='ghi', + ), + vizier_service.ListTrialsResponse( + trials=[ + study.Trial(), + study.Trial(), + ], ), - vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_trials(request={}) @@ -2501,96 +2821,147 @@ def test_list_trials_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, study.Trial) for i in results) - + assert all(isinstance(i, study.Trial) + for i in results) def test_list_trials_pages(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + with mock.patch.object( + type(client.transport.list_trials), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[study.Trial(), study.Trial(), study.Trial(),], - next_page_token="abc", + trials=[ + study.Trial(), + study.Trial(), + study.Trial(), + ], + next_page_token='abc', + ), + vizier_service.ListTrialsResponse( + trials=[], + next_page_token='def', + ), + vizier_service.ListTrialsResponse( + trials=[ + study.Trial(), + ], + next_page_token='ghi', ), - vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[study.Trial(),], next_page_token="ghi", + trials=[ + study.Trial(), + study.Trial(), + ], ), - vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) pages = list(client.list_trials(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_trials_async_pager(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_trials), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_trials), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[study.Trial(), study.Trial(), study.Trial(),], - next_page_token="abc", + trials=[ + study.Trial(), + study.Trial(), + study.Trial(), + ], + next_page_token='abc', + ), + vizier_service.ListTrialsResponse( + trials=[], + next_page_token='def', + ), + vizier_service.ListTrialsResponse( + trials=[ + study.Trial(), + ], + next_page_token='ghi', ), - vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[study.Trial(),], next_page_token="ghi", + trials=[ + study.Trial(), + study.Trial(), + ], ), - vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) async_pager = await client.list_trials(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, study.Trial) for i in responses) - + assert all(isinstance(i, study.Trial) + for i in responses) @pytest.mark.asyncio async def test_list_trials_async_pages(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_trials), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_trials), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[study.Trial(), study.Trial(), study.Trial(),], - next_page_token="abc", + trials=[ + study.Trial(), + study.Trial(), + study.Trial(), + ], + next_page_token='abc', + ), + vizier_service.ListTrialsResponse( + trials=[], + next_page_token='def', + ), + vizier_service.ListTrialsResponse( + trials=[ + study.Trial(), + ], + next_page_token='ghi', ), - vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[study.Trial(),], next_page_token="ghi", + trials=[ + study.Trial(), + study.Trial(), + ], ), - vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_trials(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_add_trial_measurement( - transport: str = "grpc", request_type=vizier_service.AddTrialMeasurementRequest -): +def test_add_trial_measurement(transport: str = 'grpc', request_type=vizier_service.AddTrialMeasurementRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2599,14 +2970,22 @@ def test_add_trial_measurement( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), "__call__" - ) as call: + type(client.transport.add_trial_measurement), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name="name_value", - id="id_value", + name='name_value', + + id='id_value', + state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", + + client_id='client_id_value', + + infeasible_reason='infeasible_reason_value', + + custom_job='custom_job_value', + ) response = client.add_trial_measurement(request) @@ -2621,13 +3000,17 @@ def test_add_trial_measurement( assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' def test_add_trial_measurement_from_dict(): @@ -2638,27 +3021,25 @@ def test_add_trial_measurement_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), "__call__" - ) as call: + type(client.transport.add_trial_measurement), + '__call__') as call: client.add_trial_measurement() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.AddTrialMeasurementRequest() - @pytest.mark.asyncio -async def test_add_trial_measurement_async( - transport: str = "grpc_asyncio", - request_type=vizier_service.AddTrialMeasurementRequest, -): +async def test_add_trial_measurement_async(transport: str = 'grpc_asyncio', request_type=vizier_service.AddTrialMeasurementRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2667,17 +3048,17 @@ async def test_add_trial_measurement_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), "__call__" - ) as call: + type(client.transport.add_trial_measurement), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Trial( - name="name_value", - id="id_value", - state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( + name='name_value', + id='id_value', + state=study.Trial.State.REQUESTED, + client_id='client_id_value', + infeasible_reason='infeasible_reason_value', + custom_job='custom_job_value', + )) response = await client.add_trial_measurement(request) @@ -2690,13 +3071,17 @@ async def test_add_trial_measurement_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' @pytest.mark.asyncio @@ -2705,17 +3090,19 @@ async def test_add_trial_measurement_async_from_dict(): def test_add_trial_measurement_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.AddTrialMeasurementRequest() - request.trial_name = "trial_name/value" + request.trial_name = 'trial_name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), "__call__" - ) as call: + type(client.transport.add_trial_measurement), + '__call__') as call: call.return_value = study.Trial() client.add_trial_measurement(request) @@ -2727,22 +3114,27 @@ def test_add_trial_measurement_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'trial_name=trial_name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_add_trial_measurement_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.AddTrialMeasurementRequest() - request.trial_name = "trial_name/value" + request.trial_name = 'trial_name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), "__call__" - ) as call: + type(client.transport.add_trial_measurement), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.add_trial_measurement(request) @@ -2754,14 +3146,16 @@ async def test_add_trial_measurement_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'trial_name=trial_name/value', + ) in kw['metadata'] -def test_complete_trial( - transport: str = "grpc", request_type=vizier_service.CompleteTrialRequest -): +def test_complete_trial(transport: str = 'grpc', request_type=vizier_service.CompleteTrialRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2769,13 +3163,23 @@ def test_complete_trial( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.complete_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name="name_value", - id="id_value", + name='name_value', + + id='id_value', + state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", + + client_id='client_id_value', + + infeasible_reason='infeasible_reason_value', + + custom_job='custom_job_value', + ) response = client.complete_trial(request) @@ -2790,13 +3194,17 @@ def test_complete_trial( assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' def test_complete_trial_from_dict(): @@ -2807,24 +3215,25 @@ def test_complete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.complete_trial), + '__call__') as call: client.complete_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CompleteTrialRequest() - @pytest.mark.asyncio -async def test_complete_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.CompleteTrialRequest -): +async def test_complete_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CompleteTrialRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2832,16 +3241,18 @@ async def test_complete_trial_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.complete_trial), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Trial( - name="name_value", - id="id_value", - state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( + name='name_value', + id='id_value', + state=study.Trial.State.REQUESTED, + client_id='client_id_value', + infeasible_reason='infeasible_reason_value', + custom_job='custom_job_value', + )) response = await client.complete_trial(request) @@ -2854,13 +3265,17 @@ async def test_complete_trial_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' @pytest.mark.asyncio @@ -2869,15 +3284,19 @@ async def test_complete_trial_async_from_dict(): def test_complete_trial_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CompleteTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.complete_trial), + '__call__') as call: call.return_value = study.Trial() client.complete_trial(request) @@ -2889,20 +3308,27 @@ def test_complete_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_complete_trial_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CompleteTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.complete_trial), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.complete_trial(request) @@ -2914,14 +3340,16 @@ async def test_complete_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] -def test_delete_trial( - transport: str = "grpc", request_type=vizier_service.DeleteTrialRequest -): +def test_delete_trial(transport: str = 'grpc', request_type=vizier_service.DeleteTrialRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2929,7 +3357,9 @@ def test_delete_trial( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2953,24 +3383,25 @@ def test_delete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: client.delete_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.DeleteTrialRequest() - @pytest.mark.asyncio -async def test_delete_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.DeleteTrialRequest -): +async def test_delete_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.DeleteTrialRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2978,7 +3409,9 @@ async def test_delete_trial_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -3000,15 +3433,19 @@ async def test_delete_trial_async_from_dict(): def test_delete_trial_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: call.return_value = None client.delete_trial(request) @@ -3020,20 +3457,27 @@ def test_delete_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_trial_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.delete_trial(request) @@ -3045,80 +3489,99 @@ async def test_delete_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_trial_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_trial(name="name_value",) + client.delete_trial( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_trial_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_trial( - vizier_service.DeleteTrialRequest(), name="name_value", + vizier_service.DeleteTrialRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_trial_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_trial(name="name_value",) + response = await client.delete_trial( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_trial_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_trial( - vizier_service.DeleteTrialRequest(), name="name_value", + vizier_service.DeleteTrialRequest(), + name='name_value', ) -def test_check_trial_early_stopping_state( - transport: str = "grpc", - request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, -): +def test_check_trial_early_stopping_state(transport: str = 'grpc', request_type=vizier_service.CheckTrialEarlyStoppingStateRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3127,10 +3590,10 @@ def test_check_trial_early_stopping_state( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), "__call__" - ) as call: + type(client.transport.check_trial_early_stopping_state), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.check_trial_early_stopping_state(request) @@ -3152,27 +3615,25 @@ def test_check_trial_early_stopping_state_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), "__call__" - ) as call: + type(client.transport.check_trial_early_stopping_state), + '__call__') as call: client.check_trial_early_stopping_state() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() - @pytest.mark.asyncio -async def test_check_trial_early_stopping_state_async( - transport: str = "grpc_asyncio", - request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, -): +async def test_check_trial_early_stopping_state_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CheckTrialEarlyStoppingStateRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3181,11 +3642,11 @@ async def test_check_trial_early_stopping_state_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), "__call__" - ) as call: + type(client.transport.check_trial_early_stopping_state), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.check_trial_early_stopping_state(request) @@ -3206,18 +3667,20 @@ async def test_check_trial_early_stopping_state_async_from_dict(): def test_check_trial_early_stopping_state_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CheckTrialEarlyStoppingStateRequest() - request.trial_name = "trial_name/value" + request.trial_name = 'trial_name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.check_trial_early_stopping_state), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.check_trial_early_stopping_state(request) @@ -3228,25 +3691,28 @@ def test_check_trial_early_stopping_state_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'trial_name=trial_name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_check_trial_early_stopping_state_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CheckTrialEarlyStoppingStateRequest() - request.trial_name = "trial_name/value" + request.trial_name = 'trial_name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.check_trial_early_stopping_state), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.check_trial_early_stopping_state(request) @@ -3257,14 +3723,16 @@ async def test_check_trial_early_stopping_state_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'trial_name=trial_name/value', + ) in kw['metadata'] -def test_stop_trial( - transport: str = "grpc", request_type=vizier_service.StopTrialRequest -): +def test_stop_trial(transport: str = 'grpc', request_type=vizier_service.StopTrialRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3272,13 +3740,23 @@ def test_stop_trial( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.stop_trial), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name="name_value", - id="id_value", + name='name_value', + + id='id_value', + state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", + + client_id='client_id_value', + + infeasible_reason='infeasible_reason_value', + + custom_job='custom_job_value', + ) response = client.stop_trial(request) @@ -3293,13 +3771,17 @@ def test_stop_trial( assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' def test_stop_trial_from_dict(): @@ -3310,24 +3792,25 @@ def test_stop_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.stop_trial), + '__call__') as call: client.stop_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.StopTrialRequest() - @pytest.mark.asyncio -async def test_stop_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.StopTrialRequest -): +async def test_stop_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.StopTrialRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3335,16 +3818,18 @@ async def test_stop_trial_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.stop_trial), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - study.Trial( - name="name_value", - id="id_value", - state=study.Trial.State.REQUESTED, - custom_job="custom_job_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( + name='name_value', + id='id_value', + state=study.Trial.State.REQUESTED, + client_id='client_id_value', + infeasible_reason='infeasible_reason_value', + custom_job='custom_job_value', + )) response = await client.stop_trial(request) @@ -3357,13 +3842,17 @@ async def test_stop_trial_async( # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.id == "id_value" + assert response.id == 'id_value' assert response.state == study.Trial.State.REQUESTED - assert response.custom_job == "custom_job_value" + assert response.client_id == 'client_id_value' + + assert response.infeasible_reason == 'infeasible_reason_value' + + assert response.custom_job == 'custom_job_value' @pytest.mark.asyncio @@ -3372,15 +3861,19 @@ async def test_stop_trial_async_from_dict(): def test_stop_trial_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.StopTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.stop_trial), + '__call__') as call: call.return_value = study.Trial() client.stop_trial(request) @@ -3392,20 +3885,27 @@ def test_stop_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_stop_trial_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.StopTrialRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + with mock.patch.object( + type(client.transport.stop_trial), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.stop_trial(request) @@ -3417,14 +3917,16 @@ async def test_stop_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] -def test_list_optimal_trials( - transport: str = "grpc", request_type=vizier_service.ListOptimalTrialsRequest -): +def test_list_optimal_trials(transport: str = 'grpc', request_type=vizier_service.ListOptimalTrialsRequest): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3433,10 +3935,11 @@ def test_list_optimal_trials( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: + type(client.transport.list_optimal_trials), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = vizier_service.ListOptimalTrialsResponse() + call.return_value = vizier_service.ListOptimalTrialsResponse( + ) response = client.list_optimal_trials(request) @@ -3459,27 +3962,25 @@ def test_list_optimal_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: + type(client.transport.list_optimal_trials), + '__call__') as call: client.list_optimal_trials() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.ListOptimalTrialsRequest() - @pytest.mark.asyncio -async def test_list_optimal_trials_async( - transport: str = "grpc_asyncio", - request_type=vizier_service.ListOptimalTrialsRequest, -): +async def test_list_optimal_trials_async(transport: str = 'grpc_asyncio', request_type=vizier_service.ListOptimalTrialsRequest): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3488,12 +3989,11 @@ async def test_list_optimal_trials_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: + type(client.transport.list_optimal_trials), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListOptimalTrialsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListOptimalTrialsResponse( + )) response = await client.list_optimal_trials(request) @@ -3513,17 +4013,19 @@ async def test_list_optimal_trials_async_from_dict(): def test_list_optimal_trials_field_headers(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListOptimalTrialsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: + type(client.transport.list_optimal_trials), + '__call__') as call: call.return_value = vizier_service.ListOptimalTrialsResponse() client.list_optimal_trials(request) @@ -3535,25 +4037,28 @@ def test_list_optimal_trials_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_optimal_trials_field_headers_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListOptimalTrialsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListOptimalTrialsResponse() - ) + type(client.transport.list_optimal_trials), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListOptimalTrialsResponse()) await client.list_optimal_trials(request) @@ -3564,77 +4069,92 @@ async def test_list_optimal_trials_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_optimal_trials_flattened(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: + type(client.transport.list_optimal_trials), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListOptimalTrialsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_optimal_trials(parent="parent_value",) + client.list_optimal_trials( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_optimal_trials_flattened_error(): - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_optimal_trials( - vizier_service.ListOptimalTrialsRequest(), parent="parent_value", + vizier_service.ListOptimalTrialsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_optimal_trials_flattened_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), "__call__" - ) as call: + type(client.transport.list_optimal_trials), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListOptimalTrialsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - vizier_service.ListOptimalTrialsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListOptimalTrialsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_optimal_trials(parent="parent_value",) + response = await client.list_optimal_trials( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_optimal_trials_flattened_error_async(): - client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_optimal_trials( - vizier_service.ListOptimalTrialsRequest(), parent="parent_value", + vizier_service.ListOptimalTrialsRequest(), + parent='parent_value', ) @@ -3645,7 +4165,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3664,7 +4185,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = VizierServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -3692,16 +4214,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.VizierServiceGrpcTransport, - transports.VizierServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3709,8 +4228,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.VizierServiceGrpcTransport,) + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.VizierServiceGrpcTransport, + ) def test_vizier_service_base_transport_error(): @@ -3718,15 +4242,13 @@ def test_vizier_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.VizierServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_vizier_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.VizierServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3735,22 +4257,22 @@ def test_vizier_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_study", - "get_study", - "list_studies", - "delete_study", - "lookup_study", - "suggest_trials", - "create_trial", - "get_trial", - "list_trials", - "add_trial_measurement", - "complete_trial", - "delete_trial", - "check_trial_early_stopping_state", - "stop_trial", - "list_optimal_trials", - ) + 'create_study', + 'get_study', + 'list_studies', + 'delete_study', + 'lookup_study', + 'suggest_trials', + 'create_trial', + 'get_trial', + 'list_trials', + 'add_trial_measurement', + 'complete_trial', + 'delete_trial', + 'check_trial_early_stopping_state', + 'stop_trial', + 'list_optimal_trials', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3763,28 +4285,23 @@ def test_vizier_service_base_transport(): def test_vizier_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.VizierServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_vizier_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.VizierServiceTransport() @@ -3793,11 +4310,11 @@ def test_vizier_service_base_transport_with_adc(): def test_vizier_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) VizierServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -3805,25 +4322,19 @@ def test_vizier_service_auth_adc(): def test_vizier_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.VizierServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.VizierServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) -@pytest.mark.parametrize( - "transport_class", - [ - transports.VizierServiceGrpcTransport, - transports.VizierServiceGrpcAsyncIOTransport, - ], -) -def test_vizier_service_grpc_transport_client_cert_source_for_mtls(transport_class): +@pytest.mark.parametrize("transport_class", [transports.VizierServiceGrpcTransport, transports.VizierServiceGrpcAsyncIOTransport]) +def test_vizier_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3832,13 +4343,15 @@ def test_vizier_service_grpc_transport_client_cert_source_for_mtls(transport_cla transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds, + ssl_channel_credentials=mock_ssl_channel_creds ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3853,40 +4366,38 @@ def test_vizier_service_grpc_transport_client_cert_source_for_mtls(transport_cla with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback, + client_cert_source_for_mtls=client_cert_source_callback ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, private_key=expected_key + certificate_chain=expected_cert, + private_key=expected_key ) def test_vizier_service_host_no_port(): client = VizierServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_vizier_service_host_with_port(): client = VizierServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_vizier_service_grpc_transport_channel(): - channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.VizierServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3894,11 +4405,12 @@ def test_vizier_service_grpc_transport_channel(): def test_vizier_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.VizierServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3907,20 +4419,12 @@ def test_vizier_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.VizierServiceGrpcTransport, - transports.VizierServiceGrpcAsyncIOTransport, - ], -) -def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.VizierServiceGrpcTransport, transports.VizierServiceGrpcAsyncIOTransport]) +def test_vizier_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3929,7 +4433,7 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3945,7 +4449,9 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3959,23 +4465,17 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize( - "transport_class", - [ - transports.VizierServiceGrpcTransport, - transports.VizierServiceGrpcAsyncIOTransport, - ], -) -def test_vizier_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.VizierServiceGrpcTransport, transports.VizierServiceGrpcAsyncIOTransport]) +def test_vizier_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel" - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3992,7 +4492,9 @@ def test_vizier_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -4005,12 +4507,16 @@ def test_vizier_service_transport_channel_mtls_with_adc(transport_class): def test_vizier_service_grpc_lro_client(): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -4018,12 +4524,16 @@ def test_vizier_service_grpc_lro_client(): def test_vizier_service_grpc_lro_async_client(): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -4034,18 +4544,17 @@ def test_custom_job_path(): location = "clam" custom_job = "whelk" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) actual = VizierServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "custom_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "custom_job": "nudibranch", + } path = VizierServiceClient.custom_job_path(**expected) @@ -4053,24 +4562,22 @@ def test_parse_custom_job_path(): actual = VizierServiceClient.parse_custom_job_path(path) assert expected == actual - def test_study_path(): project = "cuttlefish" location = "mussel" study = "winkle" - expected = "projects/{project}/locations/{location}/studies/{study}".format( - project=project, location=location, study=study, - ) + expected = "projects/{project}/locations/{location}/studies/{study}".format(project=project, location=location, study=study, ) actual = VizierServiceClient.study_path(project, location, study) assert expected == actual def test_parse_study_path(): expected = { - "project": "nautilus", - "location": "scallop", - "study": "abalone", + "project": "nautilus", + "location": "scallop", + "study": "abalone", + } path = VizierServiceClient.study_path(**expected) @@ -4078,26 +4585,24 @@ def test_parse_study_path(): actual = VizierServiceClient.parse_study_path(path) assert expected == actual - def test_trial_path(): project = "squid" location = "clam" study = "whelk" trial = "octopus" - expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( - project=project, location=location, study=study, trial=trial, - ) + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) actual = VizierServiceClient.trial_path(project, location, study, trial) assert expected == actual def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", + } path = VizierServiceClient.trial_path(**expected) @@ -4105,20 +4610,18 @@ def test_parse_trial_path(): actual = VizierServiceClient.parse_trial_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = VizierServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nautilus", + } path = VizierServiceClient.common_billing_account_path(**expected) @@ -4126,18 +4629,18 @@ def test_parse_common_billing_account_path(): actual = VizierServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = VizierServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "abalone", + } path = VizierServiceClient.common_folder_path(**expected) @@ -4145,18 +4648,18 @@ def test_parse_common_folder_path(): actual = VizierServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = VizierServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "clam", + } path = VizierServiceClient.common_organization_path(**expected) @@ -4164,18 +4667,18 @@ def test_parse_common_organization_path(): actual = VizierServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = VizierServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "octopus", + } path = VizierServiceClient.common_project_path(**expected) @@ -4183,22 +4686,20 @@ def test_parse_common_project_path(): actual = VizierServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = VizierServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "cuttlefish", + "location": "mussel", + } path = VizierServiceClient.common_location_path(**expected) @@ -4210,19 +4711,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.VizierServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.VizierServiceTransport, '_prep_wrapped_messages') as prep: client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.VizierServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.VizierServiceTransport, '_prep_wrapped_messages') as prep: transport_class = VizierServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) From 7ad77876003027e358aaee3520f25ca4d758789f Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Wed, 31 Mar 2021 10:55:58 -0400 Subject: [PATCH 02/36] chore: add constraints 3.8/3.9 --- testing/constraints-3.8.txt | 0 testing/constraints-3.9.txt | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 testing/constraints-3.8.txt create mode 100644 testing/constraints-3.9.txt diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt new file mode 100644 index 0000000000..e69de29bb2 From 4223a30df30fc703aec6bcdb4a5467d9dd43bc14 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Wed, 31 Mar 2021 11:46:49 -0400 Subject: [PATCH 03/36] chore: lint --- docs/conf.py | 6 +- .../v1/schema/predict/instance/__init__.py | 54 +- .../v1/schema/predict/instance_v1/__init__.py | 18 +- .../predict/instance_v1/types/__init__.py | 54 +- .../instance_v1/types/image_classification.py | 6 +- .../types/image_object_detection.py | 6 +- .../instance_v1/types/image_segmentation.py | 6 +- .../instance_v1/types/text_classification.py | 6 +- .../instance_v1/types/text_extraction.py | 6 +- .../instance_v1/types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 6 +- .../instance_v1/types/video_classification.py | 6 +- .../types/video_object_tracking.py | 6 +- .../v1/schema/predict/params/__init__.py | 36 +- .../v1/schema/predict/params_v1/__init__.py | 12 +- .../predict/params_v1/types/__init__.py | 36 +- .../params_v1/types/image_classification.py | 6 +- .../params_v1/types/image_object_detection.py | 6 +- .../params_v1/types/image_segmentation.py | 6 +- .../types/video_action_recognition.py | 6 +- .../params_v1/types/video_classification.py | 6 +- .../params_v1/types/video_object_tracking.py | 6 +- .../v1/schema/predict/prediction/__init__.py | 60 +- .../schema/predict/prediction_v1/__init__.py | 20 +- .../predict/prediction_v1/types/__init__.py | 60 +- .../prediction_v1/types/classification.py | 6 +- .../types/image_object_detection.py | 10 +- .../prediction_v1/types/image_segmentation.py | 6 +- .../types/tabular_classification.py | 6 +- .../prediction_v1/types/tabular_regression.py | 6 +- .../prediction_v1/types/text_extraction.py | 6 +- .../prediction_v1/types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 18 +- .../types/video_classification.py | 18 +- .../types/video_object_tracking.py | 43 +- .../schema/trainingjob/definition/__init__.py | 150 +- .../trainingjob/definition_v1/__init__.py | 50 +- .../definition_v1/types/__init__.py | 54 +- .../types/automl_image_classification.py | 26 +- .../types/automl_image_object_detection.py | 26 +- .../types/automl_image_segmentation.py | 26 +- .../definition_v1/types/automl_tables.py | 94 +- .../types/automl_text_classification.py | 11 +- .../types/automl_text_extraction.py | 11 +- .../types/automl_text_sentiment.py | 11 +- .../types/automl_video_action_recognition.py | 16 +- .../types/automl_video_classification.py | 16 +- .../types/automl_video_object_tracking.py | 16 +- .../export_evaluated_data_items_config.py | 6 +- .../schema/predict/instance/__init__.py | 54 +- .../predict/instance_v1beta1/__init__.py | 18 +- .../instance_v1beta1/types/__init__.py | 54 +- .../types/image_classification.py | 6 +- .../types/image_object_detection.py | 6 +- .../types/image_segmentation.py | 6 +- .../types/text_classification.py | 6 +- .../instance_v1beta1/types/text_extraction.py | 6 +- .../instance_v1beta1/types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 6 +- .../types/video_classification.py | 6 +- .../types/video_object_tracking.py | 6 +- .../v1beta1/schema/predict/params/__init__.py | 36 +- .../schema/predict/params_v1beta1/__init__.py | 12 +- .../predict/params_v1beta1/types/__init__.py | 36 +- .../types/image_classification.py | 6 +- .../types/image_object_detection.py | 6 +- .../types/image_segmentation.py | 6 +- .../types/video_action_recognition.py | 6 +- .../types/video_classification.py | 6 +- .../types/video_object_tracking.py | 6 +- .../schema/predict/prediction/__init__.py | 60 +- .../predict/prediction_v1beta1/__init__.py | 20 +- .../prediction_v1beta1/types/__init__.py | 60 +- .../types/classification.py | 6 +- .../types/image_object_detection.py | 10 +- .../types/image_segmentation.py | 6 +- .../types/tabular_classification.py | 6 +- .../types/tabular_regression.py | 6 +- .../types/text_extraction.py | 6 +- .../types/text_sentiment.py | 6 +- .../types/video_action_recognition.py | 18 +- .../types/video_classification.py | 18 +- .../types/video_object_tracking.py | 43 +- .../schema/trainingjob/definition/__init__.py | 150 +- .../definition_v1beta1/__init__.py | 50 +- .../definition_v1beta1/types/__init__.py | 54 +- .../types/automl_image_classification.py | 26 +- .../types/automl_image_object_detection.py | 26 +- .../types/automl_image_segmentation.py | 26 +- .../definition_v1beta1/types/automl_tables.py | 94 +- .../types/automl_text_classification.py | 11 +- .../types/automl_text_extraction.py | 11 +- .../types/automl_text_sentiment.py | 11 +- .../types/automl_video_action_recognition.py | 16 +- .../types/automl_video_classification.py | 16 +- .../types/automl_video_object_tracking.py | 16 +- .../export_evaluated_data_items_config.py | 6 +- google/cloud/aiplatform_v1/__init__.py | 324 +- .../services/dataset_service/__init__.py | 4 +- .../services/dataset_service/async_client.py | 439 +- .../services/dataset_service/client.py | 543 +- .../services/dataset_service/pagers.py | 113 +- .../dataset_service/transports/__init__.py | 10 +- .../dataset_service/transports/base.py | 223 +- .../dataset_service/transports/grpc.py | 212 +- .../transports/grpc_asyncio.py | 220 +- .../services/endpoint_service/__init__.py | 4 +- .../services/endpoint_service/async_client.py | 331 +- .../services/endpoint_service/client.py | 400 +- .../services/endpoint_service/pagers.py | 45 +- .../endpoint_service/transports/__init__.py | 10 +- .../endpoint_service/transports/base.py | 166 +- .../endpoint_service/transports/grpc.py | 163 +- .../transports/grpc_asyncio.py | 173 +- .../services/job_service/__init__.py | 4 +- .../services/job_service/async_client.py | 794 ++- .../services/job_service/client.py | 954 ++- .../services/job_service/pagers.py | 157 +- .../job_service/transports/__init__.py | 10 +- .../services/job_service/transports/base.py | 351 +- .../services/job_service/transports/grpc.py | 391 +- .../job_service/transports/grpc_asyncio.py | 405 +- .../services/migration_service/__init__.py | 4 +- .../migration_service/async_client.py | 151 +- .../services/migration_service/client.py | 282 +- .../services/migration_service/pagers.py | 51 +- .../migration_service/transports/__init__.py | 10 +- .../migration_service/transports/base.py | 78 +- .../migration_service/transports/grpc.py | 96 +- .../transports/grpc_asyncio.py | 96 +- .../services/model_service/__init__.py | 4 +- .../services/model_service/async_client.py | 441 +- .../services/model_service/client.py | 555 +- .../services/model_service/pagers.py | 119 +- .../model_service/transports/__init__.py | 10 +- .../services/model_service/transports/base.py | 210 +- .../services/model_service/transports/grpc.py | 212 +- .../model_service/transports/grpc_asyncio.py | 216 +- .../services/pipeline_service/__init__.py | 4 +- .../services/pipeline_service/async_client.py | 249 +- .../services/pipeline_service/client.py | 329 +- .../services/pipeline_service/pagers.py | 51 +- .../pipeline_service/transports/__init__.py | 10 +- .../pipeline_service/transports/base.py | 120 +- .../pipeline_service/transports/grpc.py | 144 +- .../transports/grpc_asyncio.py | 146 +- .../services/prediction_service/__init__.py | 4 +- .../prediction_service/async_client.py | 108 +- .../services/prediction_service/client.py | 166 +- .../prediction_service/transports/__init__.py | 10 +- .../prediction_service/transports/base.py | 70 +- .../prediction_service/transports/grpc.py | 75 +- .../transports/grpc_asyncio.py | 77 +- .../specialist_pool_service/__init__.py | 4 +- .../specialist_pool_service/async_client.py | 264 +- .../specialist_pool_service/client.py | 309 +- .../specialist_pool_service/pagers.py | 51 +- .../transports/__init__.py | 14 +- .../transports/base.py | 121 +- .../transports/grpc.py | 145 +- .../transports/grpc_asyncio.py | 147 +- google/cloud/aiplatform_v1/types/__init__.py | 368 +- .../aiplatform_v1/types/accelerator_type.py | 5 +- .../cloud/aiplatform_v1/types/annotation.py | 21 +- .../aiplatform_v1/types/annotation_spec.py | 13 +- .../types/batch_prediction_job.py | 99 +- .../aiplatform_v1/types/completion_stats.py | 5 +- .../cloud/aiplatform_v1/types/custom_job.py | 86 +- google/cloud/aiplatform_v1/types/data_item.py | 17 +- .../aiplatform_v1/types/data_labeling_job.py | 71 +- google/cloud/aiplatform_v1/types/dataset.py | 32 +- .../aiplatform_v1/types/dataset_service.py | 102 +- .../aiplatform_v1/types/deployed_model_ref.py | 5 +- .../aiplatform_v1/types/encryption_spec.py | 5 +- google/cloud/aiplatform_v1/types/endpoint.py | 36 +- .../aiplatform_v1/types/endpoint_service.py | 68 +- google/cloud/aiplatform_v1/types/env_var.py | 7 +- .../types/hyperparameter_tuning_job.py | 45 +- google/cloud/aiplatform_v1/types/io.py | 12 +- .../cloud/aiplatform_v1/types/job_service.py | 106 +- google/cloud/aiplatform_v1/types/job_state.py | 5 +- .../aiplatform_v1/types/machine_resources.py | 26 +- .../types/manual_batch_tuning_parameters.py | 5 +- .../types/migratable_resource.py | 37 +- .../aiplatform_v1/types/migration_service.py | 87 +- google/cloud/aiplatform_v1/types/model.py | 59 +- .../aiplatform_v1/types/model_evaluation.py | 13 +- .../types/model_evaluation_slice.py | 18 +- .../aiplatform_v1/types/model_service.py | 98 +- google/cloud/aiplatform_v1/types/operation.py | 23 +- .../aiplatform_v1/types/pipeline_service.py | 26 +- .../aiplatform_v1/types/pipeline_state.py | 5 +- .../aiplatform_v1/types/prediction_service.py | 19 +- .../aiplatform_v1/types/specialist_pool.py | 5 +- .../types/specialist_pool_service.py | 46 +- google/cloud/aiplatform_v1/types/study.py | 129 +- .../aiplatform_v1/types/training_pipeline.py | 82 +- .../types/user_action_reference.py | 9 +- google/cloud/aiplatform_v1beta1/__init__.py | 550 +- .../services/dataset_service/__init__.py | 4 +- .../services/dataset_service/async_client.py | 439 +- .../services/dataset_service/client.py | 543 +- .../services/dataset_service/pagers.py | 113 +- .../dataset_service/transports/__init__.py | 10 +- .../dataset_service/transports/base.py | 223 +- .../dataset_service/transports/grpc.py | 212 +- .../transports/grpc_asyncio.py | 220 +- .../services/endpoint_service/__init__.py | 4 +- .../services/endpoint_service/async_client.py | 331 +- .../services/endpoint_service/client.py | 400 +- .../services/endpoint_service/pagers.py | 45 +- .../endpoint_service/transports/__init__.py | 10 +- .../endpoint_service/transports/base.py | 166 +- .../endpoint_service/transports/grpc.py | 163 +- .../transports/grpc_asyncio.py | 173 +- .../services/job_service/__init__.py | 4 +- .../services/job_service/async_client.py | 1120 ++-- .../services/job_service/client.py | 1386 ++--- .../services/job_service/pagers.py | 278 +- .../job_service/transports/__init__.py | 10 +- .../services/job_service/transports/base.py | 473 +- .../services/job_service/transports/grpc.py | 543 +- .../job_service/transports/grpc_asyncio.py | 563 +- .../services/metadata_service/__init__.py | 4 +- .../services/metadata_service/async_client.py | 997 ++-- .../services/metadata_service/client.py | 1137 ++-- .../services/metadata_service/pagers.py | 185 +- .../metadata_service/transports/__init__.py | 10 +- .../metadata_service/transports/base.py | 453 +- .../metadata_service/transports/grpc.py | 454 +- .../transports/grpc_asyncio.py | 477 +- .../services/migration_service/__init__.py | 4 +- .../migration_service/async_client.py | 151 +- .../services/migration_service/client.py | 282 +- .../services/migration_service/pagers.py | 51 +- .../migration_service/transports/__init__.py | 10 +- .../migration_service/transports/base.py | 78 +- .../migration_service/transports/grpc.py | 96 +- .../transports/grpc_asyncio.py | 96 +- .../services/model_service/__init__.py | 4 +- .../services/model_service/async_client.py | 441 +- .../services/model_service/client.py | 555 +- .../services/model_service/pagers.py | 119 +- .../model_service/transports/__init__.py | 10 +- .../services/model_service/transports/base.py | 214 +- .../services/model_service/transports/grpc.py | 212 +- .../model_service/transports/grpc_asyncio.py | 216 +- .../services/pipeline_service/__init__.py | 4 +- .../services/pipeline_service/async_client.py | 253 +- .../services/pipeline_service/client.py | 333 +- .../services/pipeline_service/pagers.py | 51 +- .../pipeline_service/transports/__init__.py | 10 +- .../pipeline_service/transports/base.py | 124 +- .../pipeline_service/transports/grpc.py | 148 +- .../transports/grpc_asyncio.py | 150 +- .../services/prediction_service/__init__.py | 4 +- .../prediction_service/async_client.py | 148 +- .../services/prediction_service/client.py | 206 +- .../prediction_service/transports/__init__.py | 10 +- .../prediction_service/transports/base.py | 89 +- .../prediction_service/transports/grpc.py | 91 +- .../transports/grpc_asyncio.py | 94 +- .../specialist_pool_service/__init__.py | 4 +- .../specialist_pool_service/async_client.py | 264 +- .../specialist_pool_service/client.py | 309 +- .../specialist_pool_service/pagers.py | 51 +- .../transports/__init__.py | 14 +- .../transports/base.py | 121 +- .../transports/grpc.py | 145 +- .../transports/grpc_asyncio.py | 147 +- .../services/vizier_service/__init__.py | 4 +- .../services/vizier_service/async_client.py | 554 +- .../services/vizier_service/client.py | 636 +- .../services/vizier_service/pagers.py | 79 +- .../vizier_service/transports/__init__.py | 10 +- .../vizier_service/transports/base.py | 292 +- .../vizier_service/transports/grpc.py | 278 +- .../vizier_service/transports/grpc_asyncio.py | 287 +- .../aiplatform_v1beta1/types/__init__.py | 610 +- .../types/accelerator_type.py | 5 +- .../aiplatform_v1beta1/types/annotation.py | 21 +- .../types/annotation_spec.py | 13 +- .../aiplatform_v1beta1/types/artifact.py | 22 +- .../types/batch_prediction_job.py | 107 +- .../types/completion_stats.py | 5 +- .../cloud/aiplatform_v1beta1/types/context.py | 17 +- .../aiplatform_v1beta1/types/custom_job.py | 78 +- .../aiplatform_v1beta1/types/data_item.py | 17 +- .../types/data_labeling_job.py | 71 +- .../cloud/aiplatform_v1beta1/types/dataset.py | 32 +- .../types/dataset_service.py | 102 +- .../types/deployed_model_ref.py | 5 +- .../types/encryption_spec.py | 5 +- .../aiplatform_v1beta1/types/endpoint.py | 40 +- .../types/endpoint_service.py | 68 +- .../cloud/aiplatform_v1beta1/types/env_var.py | 5 +- .../cloud/aiplatform_v1beta1/types/event.py | 14 +- .../aiplatform_v1beta1/types/execution.py | 22 +- .../aiplatform_v1beta1/types/explanation.py | 104 +- .../types/explanation_metadata.py | 72 +- .../types/feature_monitoring_stats.py | 13 +- .../types/hyperparameter_tuning_job.py | 45 +- google/cloud/aiplatform_v1beta1/types/io.py | 12 +- .../aiplatform_v1beta1/types/job_service.py | 181 +- .../aiplatform_v1beta1/types/job_state.py | 5 +- .../types/lineage_subgraph.py | 17 +- .../types/machine_resources.py | 36 +- .../types/manual_batch_tuning_parameters.py | 6 +- .../types/metadata_schema.py | 14 +- .../types/metadata_service.py | 148 +- .../types/metadata_store.py | 17 +- .../types/migratable_resource.py | 37 +- .../types/migration_service.py | 87 +- .../cloud/aiplatform_v1beta1/types/model.py | 63 +- .../types/model_deployment_monitoring_job.py | 105 +- .../types/model_evaluation.py | 26 +- .../types/model_evaluation_slice.py | 18 +- .../types/model_monitoring.py | 57 +- .../aiplatform_v1beta1/types/model_service.py | 98 +- .../aiplatform_v1beta1/types/operation.py | 23 +- .../types/pipeline_service.py | 30 +- .../types/pipeline_state.py | 5 +- .../types/prediction_service.py | 42 +- .../types/specialist_pool.py | 5 +- .../types/specialist_pool_service.py | 46 +- .../cloud/aiplatform_v1beta1/types/study.py | 166 +- .../types/training_pipeline.py | 82 +- .../types/user_action_reference.py | 9 +- .../types/vizier_service.py | 98 +- noxfile.py | 54 +- tests/unit/gapic/aiplatform_v1/__init__.py | 1 - .../aiplatform_v1/test_dataset_service.py | 2265 ++++---- .../aiplatform_v1/test_endpoint_service.py | 1610 +++--- .../gapic/aiplatform_v1/test_job_service.py | 3689 ++++++------ .../aiplatform_v1/test_migration_service.py | 940 +-- .../gapic/aiplatform_v1/test_model_service.py | 2366 ++++---- .../aiplatform_v1/test_pipeline_service.py | 1283 ++-- .../test_specialist_pool_service.py | 1156 ++-- .../unit/gapic/aiplatform_v1beta1/__init__.py | 1 - .../test_dataset_service.py | 2269 ++++---- .../test_endpoint_service.py | 1614 +++--- .../aiplatform_v1beta1/test_job_service.py | 5147 ++++++++--------- .../test_metadata_service.py | 4984 +++++++--------- .../test_migration_service.py | 944 +-- .../aiplatform_v1beta1/test_model_service.py | 2370 ++++---- .../test_pipeline_service.py | 1291 +++-- .../test_specialist_pool_service.py | 1156 ++-- .../aiplatform_v1beta1/test_vizier_service.py | 2624 ++++----- 348 files changed, 35602 insertions(+), 37793 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c05116a68c..98e68be241 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -347,13 +347,9 @@ intersphinx_mapping = { "python": ("https://python.readthedocs.org/en/latest/", None), "google-auth": ("https://googleapis.dev/python/google-auth/latest/", None), - "google.api_core": ( - "https://googleapis.dev/python/google-api-core/latest/", - None, - ), + "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,), "grpc": ("https://grpc.github.io/grpc/python/", None), "proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None), - } diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py index e99be5a9d2..fb2668afb5 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py @@ -15,24 +15,42 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_classification import ImageClassificationPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_object_detection import ImageObjectDetectionPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_segmentation import ImageSegmentationPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_classification import TextClassificationPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_extraction import TextExtractionPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_sentiment import TextSentimentPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_action_recognition import VideoActionRecognitionPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_classification import VideoClassificationPredictionInstance -from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_object_tracking import VideoObjectTrackingPredictionInstance +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_classification import ( + ImageClassificationPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_object_detection import ( + ImageObjectDetectionPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_segmentation import ( + ImageSegmentationPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_classification import ( + TextClassificationPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_extraction import ( + TextExtractionPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_sentiment import ( + TextSentimentPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_action_recognition import ( + VideoActionRecognitionPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_classification import ( + VideoClassificationPredictionInstance, +) +from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_object_tracking import ( + VideoObjectTrackingPredictionInstance, +) __all__ = ( - 'ImageClassificationPredictionInstance', - 'ImageObjectDetectionPredictionInstance', - 'ImageSegmentationPredictionInstance', - 'TextClassificationPredictionInstance', - 'TextExtractionPredictionInstance', - 'TextSentimentPredictionInstance', - 'VideoActionRecognitionPredictionInstance', - 'VideoClassificationPredictionInstance', - 'VideoObjectTrackingPredictionInstance', + "ImageClassificationPredictionInstance", + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py index c68b05e778..f6d9a128ad 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py @@ -27,13 +27,13 @@ __all__ = ( - 'ImageObjectDetectionPredictionInstance', - 'ImageSegmentationPredictionInstance', - 'TextClassificationPredictionInstance', - 'TextExtractionPredictionInstance', - 'TextSentimentPredictionInstance', - 'VideoActionRecognitionPredictionInstance', - 'VideoClassificationPredictionInstance', - 'VideoObjectTrackingPredictionInstance', -'ImageClassificationPredictionInstance', + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", + "ImageClassificationPredictionInstance", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py index aacf581e2e..041fe6cdb1 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py @@ -15,42 +15,24 @@ # limitations under the License. # -from .image_classification import ( - ImageClassificationPredictionInstance, -) -from .image_object_detection import ( - ImageObjectDetectionPredictionInstance, -) -from .image_segmentation import ( - ImageSegmentationPredictionInstance, -) -from .text_classification import ( - TextClassificationPredictionInstance, -) -from .text_extraction import ( - TextExtractionPredictionInstance, -) -from .text_sentiment import ( - TextSentimentPredictionInstance, -) -from .video_action_recognition import ( - VideoActionRecognitionPredictionInstance, -) -from .video_classification import ( - VideoClassificationPredictionInstance, -) -from .video_object_tracking import ( - VideoObjectTrackingPredictionInstance, -) +from .image_classification import ImageClassificationPredictionInstance +from .image_object_detection import ImageObjectDetectionPredictionInstance +from .image_segmentation import ImageSegmentationPredictionInstance +from .text_classification import TextClassificationPredictionInstance +from .text_extraction import TextExtractionPredictionInstance +from .text_sentiment import TextSentimentPredictionInstance +from .video_action_recognition import VideoActionRecognitionPredictionInstance +from .video_classification import VideoClassificationPredictionInstance +from .video_object_tracking import VideoObjectTrackingPredictionInstance __all__ = ( - 'ImageClassificationPredictionInstance', - 'ImageObjectDetectionPredictionInstance', - 'ImageSegmentationPredictionInstance', - 'TextClassificationPredictionInstance', - 'TextExtractionPredictionInstance', - 'TextSentimentPredictionInstance', - 'VideoActionRecognitionPredictionInstance', - 'VideoClassificationPredictionInstance', - 'VideoObjectTrackingPredictionInstance', + "ImageClassificationPredictionInstance", + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py index 2b7e94a11b..b5fa9b4dbf 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'ImageClassificationPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"ImageClassificationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py index a7ad135173..45752ce7e2 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'ImageObjectDetectionPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"ImageObjectDetectionPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py index fb663cb849..cb436d7029 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'ImageSegmentationPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"ImageSegmentationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py index 1d54c594d9..ceff5308b7 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'TextClassificationPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"TextClassificationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py index 6260e4eca9..2e96216466 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'TextExtractionPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"TextExtractionPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py index ca47c08fc2..37353ad806 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'TextSentimentPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"TextSentimentPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py index 5e72ebbeae..6de5665312 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'VideoActionRecognitionPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"VideoActionRecognitionPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py index 2a302fc41f..ab7c0edfe1 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'VideoClassificationPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"VideoClassificationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py index 7f1d7b371b..f797f58f4e 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.instance', - manifest={ - 'VideoObjectTrackingPredictionInstance', - }, + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"VideoObjectTrackingPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py index 7a3e372796..c046f4d7e5 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py @@ -15,18 +15,30 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_classification import ImageClassificationPredictionParams -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_object_detection import ImageObjectDetectionPredictionParams -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_segmentation import ImageSegmentationPredictionParams -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_action_recognition import VideoActionRecognitionPredictionParams -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_classification import VideoClassificationPredictionParams -from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_object_tracking import VideoObjectTrackingPredictionParams +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_classification import ( + ImageClassificationPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_object_detection import ( + ImageObjectDetectionPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_segmentation import ( + ImageSegmentationPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_action_recognition import ( + VideoActionRecognitionPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_classification import ( + VideoClassificationPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_object_tracking import ( + VideoObjectTrackingPredictionParams, +) __all__ = ( - 'ImageClassificationPredictionParams', - 'ImageObjectDetectionPredictionParams', - 'ImageSegmentationPredictionParams', - 'VideoActionRecognitionPredictionParams', - 'VideoClassificationPredictionParams', - 'VideoObjectTrackingPredictionParams', + "ImageClassificationPredictionParams", + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py index 0e358981b3..79fb1c2097 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py @@ -24,10 +24,10 @@ __all__ = ( - 'ImageObjectDetectionPredictionParams', - 'ImageSegmentationPredictionParams', - 'VideoActionRecognitionPredictionParams', - 'VideoClassificationPredictionParams', - 'VideoObjectTrackingPredictionParams', -'ImageClassificationPredictionParams', + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", + "ImageClassificationPredictionParams", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py index 4f53fda062..2f2c29bba5 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py @@ -15,30 +15,18 @@ # limitations under the License. # -from .image_classification import ( - ImageClassificationPredictionParams, -) -from .image_object_detection import ( - ImageObjectDetectionPredictionParams, -) -from .image_segmentation import ( - ImageSegmentationPredictionParams, -) -from .video_action_recognition import ( - VideoActionRecognitionPredictionParams, -) -from .video_classification import ( - VideoClassificationPredictionParams, -) -from .video_object_tracking import ( - VideoObjectTrackingPredictionParams, -) +from .image_classification import ImageClassificationPredictionParams +from .image_object_detection import ImageObjectDetectionPredictionParams +from .image_segmentation import ImageSegmentationPredictionParams +from .video_action_recognition import VideoActionRecognitionPredictionParams +from .video_classification import VideoClassificationPredictionParams +from .video_object_tracking import VideoObjectTrackingPredictionParams __all__ = ( - 'ImageClassificationPredictionParams', - 'ImageObjectDetectionPredictionParams', - 'ImageSegmentationPredictionParams', - 'VideoActionRecognitionPredictionParams', - 'VideoClassificationPredictionParams', - 'VideoObjectTrackingPredictionParams', + "ImageClassificationPredictionParams", + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py index b29f91c772..3a9efd0ea2 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.params', - manifest={ - 'ImageClassificationPredictionParams', - }, + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"ImageClassificationPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py index 7b34fe0395..c37507a4e0 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.params', - manifest={ - 'ImageObjectDetectionPredictionParams', - }, + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"ImageObjectDetectionPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py index 3b2f2c3ff2..108cff107b 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.params', - manifest={ - 'ImageSegmentationPredictionParams', - }, + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"ImageSegmentationPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py index 9fbd7a6b6a..66f1f19e76 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.params', - manifest={ - 'VideoActionRecognitionPredictionParams', - }, + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"VideoActionRecognitionPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py index cf79e22d5f..bfe8df9f5c 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.params', - manifest={ - 'VideoClassificationPredictionParams', - }, + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"VideoClassificationPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py index 1b1b615d0a..899de1050a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.params', - manifest={ - 'VideoObjectTrackingPredictionParams', - }, + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"VideoObjectTrackingPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py index 01d2f8177a..d8e2b782c2 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py @@ -15,26 +15,46 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.classification import ClassificationPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_object_detection import ImageObjectDetectionPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_segmentation import ImageSegmentationPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_classification import TabularClassificationPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_regression import TabularRegressionPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_extraction import TextExtractionPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_sentiment import TextSentimentPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_action_recognition import VideoActionRecognitionPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_classification import VideoClassificationPredictionResult -from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_object_tracking import VideoObjectTrackingPredictionResult +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.classification import ( + ClassificationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_object_detection import ( + ImageObjectDetectionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_segmentation import ( + ImageSegmentationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_classification import ( + TabularClassificationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_regression import ( + TabularRegressionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_extraction import ( + TextExtractionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_sentiment import ( + TextSentimentPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_action_recognition import ( + VideoActionRecognitionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_classification import ( + VideoClassificationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_object_tracking import ( + VideoObjectTrackingPredictionResult, +) __all__ = ( - 'ClassificationPredictionResult', - 'ImageObjectDetectionPredictionResult', - 'ImageSegmentationPredictionResult', - 'TabularClassificationPredictionResult', - 'TabularRegressionPredictionResult', - 'TextExtractionPredictionResult', - 'TextSentimentPredictionResult', - 'VideoActionRecognitionPredictionResult', - 'VideoClassificationPredictionResult', - 'VideoObjectTrackingPredictionResult', + "ClassificationPredictionResult", + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py index 42f26f575f..91fae5a3b1 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py @@ -28,14 +28,14 @@ __all__ = ( - 'ImageObjectDetectionPredictionResult', - 'ImageSegmentationPredictionResult', - 'TabularClassificationPredictionResult', - 'TabularRegressionPredictionResult', - 'TextExtractionPredictionResult', - 'TextSentimentPredictionResult', - 'VideoActionRecognitionPredictionResult', - 'VideoClassificationPredictionResult', - 'VideoObjectTrackingPredictionResult', -'ClassificationPredictionResult', + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", + "ClassificationPredictionResult", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py index 019d5ea59c..a0fd2058e0 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py @@ -15,46 +15,26 @@ # limitations under the License. # -from .classification import ( - ClassificationPredictionResult, -) -from .image_object_detection import ( - ImageObjectDetectionPredictionResult, -) -from .image_segmentation import ( - ImageSegmentationPredictionResult, -) -from .tabular_classification import ( - TabularClassificationPredictionResult, -) -from .tabular_regression import ( - TabularRegressionPredictionResult, -) -from .text_extraction import ( - TextExtractionPredictionResult, -) -from .text_sentiment import ( - TextSentimentPredictionResult, -) -from .video_action_recognition import ( - VideoActionRecognitionPredictionResult, -) -from .video_classification import ( - VideoClassificationPredictionResult, -) -from .video_object_tracking import ( - VideoObjectTrackingPredictionResult, -) +from .classification import ClassificationPredictionResult +from .image_object_detection import ImageObjectDetectionPredictionResult +from .image_segmentation import ImageSegmentationPredictionResult +from .tabular_classification import TabularClassificationPredictionResult +from .tabular_regression import TabularRegressionPredictionResult +from .text_extraction import TextExtractionPredictionResult +from .text_sentiment import TextSentimentPredictionResult +from .video_action_recognition import VideoActionRecognitionPredictionResult +from .video_classification import VideoClassificationPredictionResult +from .video_object_tracking import VideoObjectTrackingPredictionResult __all__ = ( - 'ClassificationPredictionResult', - 'ImageObjectDetectionPredictionResult', - 'ImageSegmentationPredictionResult', - 'TabularClassificationPredictionResult', - 'TabularRegressionPredictionResult', - 'TextExtractionPredictionResult', - 'TextSentimentPredictionResult', - 'VideoActionRecognitionPredictionResult', - 'VideoClassificationPredictionResult', - 'VideoObjectTrackingPredictionResult', + "ClassificationPredictionResult", + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py index 2ae1a3a9cf..cfc8e2e602 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'ClassificationPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"ClassificationPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py index 2987851e58..31d37010db 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py @@ -22,10 +22,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'ImageObjectDetectionPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"ImageObjectDetectionPredictionResult",}, ) @@ -60,9 +58,7 @@ class ImageObjectDetectionPredictionResult(proto.Message): confidences = proto.RepeatedField(proto.FLOAT, number=3) - bboxes = proto.RepeatedField(proto.MESSAGE, number=4, - message=struct.ListValue, - ) + bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct.ListValue,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py index c12b105a2f..1261f19723 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'ImageSegmentationPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"ImageSegmentationPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py index 6ffe672140..7e78051467 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'TabularClassificationPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TabularClassificationPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py index f26cfa1b46..c813f3e45c 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'TabularRegressionPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TabularRegressionPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py index 05234d1324..201f10d08a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'TextExtractionPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TextExtractionPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py index 27501ba0a6..73c670f4ec 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'TextSentimentPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TextSentimentPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py index ad88398dc6..486853c63d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py @@ -23,10 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'VideoActionRecognitionPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"VideoActionRecognitionPredictionResult",}, ) @@ -64,17 +62,13 @@ class VideoActionRecognitionPredictionResult(proto.Message): display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field(proto.MESSAGE, number=4, - message=duration.Duration, + time_segment_start = proto.Field( + proto.MESSAGE, number=4, message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, - message=duration.Duration, - ) + time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) - confidence = proto.Field(proto.MESSAGE, number=6, - message=wrappers.FloatValue, - ) + confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py index 12f042e10e..c043547d04 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py @@ -23,10 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'VideoClassificationPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"VideoClassificationPredictionResult",}, ) @@ -80,17 +78,13 @@ class VideoClassificationPredictionResult(proto.Message): type_ = proto.Field(proto.STRING, number=3) - time_segment_start = proto.Field(proto.MESSAGE, number=4, - message=duration.Duration, + time_segment_start = proto.Field( + proto.MESSAGE, number=4, message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, - message=duration.Duration, - ) + time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) - confidence = proto.Field(proto.MESSAGE, number=6, - message=wrappers.FloatValue, - ) + confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py index 672c039bc6..d1b515a895 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py @@ -23,10 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.predict.prediction', - manifest={ - 'VideoObjectTrackingPredictionResult', - }, + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"VideoObjectTrackingPredictionResult",}, ) @@ -64,6 +62,7 @@ class VideoObjectTrackingPredictionResult(proto.Message): bounding boxes in the frames identify the same object. """ + class Frame(proto.Message): r"""The fields ``xMin``, ``xMax``, ``yMin``, and ``yMax`` refer to a bounding box, i.e. the rectangle over the video frame pinpointing @@ -88,45 +87,29 @@ class Frame(proto.Message): box. """ - time_offset = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + time_offset = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) - x_min = proto.Field(proto.MESSAGE, number=2, - message=wrappers.FloatValue, - ) + x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers.FloatValue,) - x_max = proto.Field(proto.MESSAGE, number=3, - message=wrappers.FloatValue, - ) + x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers.FloatValue,) - y_min = proto.Field(proto.MESSAGE, number=4, - message=wrappers.FloatValue, - ) + y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers.FloatValue,) - y_max = proto.Field(proto.MESSAGE, number=5, - message=wrappers.FloatValue, - ) + y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) id = proto.Field(proto.STRING, number=1) display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field(proto.MESSAGE, number=3, - message=duration.Duration, + time_segment_start = proto.Field( + proto.MESSAGE, number=3, message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=4, - message=duration.Duration, - ) + time_segment_end = proto.Field(proto.MESSAGE, number=4, message=duration.Duration,) - confidence = proto.Field(proto.MESSAGE, number=5, - message=wrappers.FloatValue, - ) + confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) - frames = proto.RepeatedField(proto.MESSAGE, number=6, - message=Frame, - ) + frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py index 1f57aea67f..f8620bb25d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py @@ -15,56 +15,106 @@ # limitations under the License. # -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import AutoMlImageClassification -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import AutoMlImageClassificationInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import AutoMlImageClassificationMetadata -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import AutoMlImageObjectDetection -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import AutoMlImageObjectDetectionInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import AutoMlImageObjectDetectionMetadata -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import AutoMlImageSegmentation -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import AutoMlImageSegmentationInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import AutoMlImageSegmentationMetadata -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import AutoMlTables -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import AutoMlTablesInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import AutoMlTablesMetadata -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import AutoMlTextClassification -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import AutoMlTextClassificationInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import AutoMlTextExtraction -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import AutoMlTextExtractionInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import AutoMlTextSentiment -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import AutoMlTextSentimentInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import AutoMlVideoActionRecognition -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import AutoMlVideoActionRecognitionInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import AutoMlVideoClassification -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import AutoMlVideoClassificationInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import AutoMlVideoObjectTracking -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import AutoMlVideoObjectTrackingInputs -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( + AutoMlImageClassification, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( + AutoMlImageClassificationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( + AutoMlImageClassificationMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( + AutoMlImageObjectDetection, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( + AutoMlImageObjectDetectionInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( + AutoMlImageObjectDetectionMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( + AutoMlImageSegmentation, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( + AutoMlImageSegmentationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( + AutoMlImageSegmentationMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( + AutoMlTables, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( + AutoMlTablesInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( + AutoMlTablesMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import ( + AutoMlTextClassification, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import ( + AutoMlTextClassificationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import ( + AutoMlTextExtraction, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import ( + AutoMlTextExtractionInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import ( + AutoMlTextSentiment, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import ( + AutoMlTextSentimentInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import ( + AutoMlVideoActionRecognition, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import ( + AutoMlVideoActionRecognitionInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import ( + AutoMlVideoClassification, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import ( + AutoMlVideoClassificationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import ( + AutoMlVideoObjectTracking, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import ( + AutoMlVideoObjectTrackingInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.export_evaluated_data_items_config import ( + ExportEvaluatedDataItemsConfig, +) __all__ = ( - 'AutoMlImageClassification', - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - 'ExportEvaluatedDataItemsConfig', + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py index 135e04f228..34958e5add 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py @@ -43,29 +43,29 @@ __all__ = ( - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - 'ExportEvaluatedDataItemsConfig', -'AutoMlImageClassification', + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", + "AutoMlImageClassification", ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py index 2d7d19c057..a15aa2c041 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py @@ -59,34 +59,32 @@ AutoMlVideoObjectTracking, AutoMlVideoObjectTrackingInputs, ) -from .export_evaluated_data_items_config import ( - ExportEvaluatedDataItemsConfig, -) +from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig __all__ = ( - 'AutoMlImageClassification', - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - 'ExportEvaluatedDataItemsConfig', + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py index 530007c977..f7e13c60b7 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", manifest={ - 'AutoMlImageClassification', - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", }, ) @@ -39,12 +39,12 @@ class AutoMlImageClassification(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlImageClassificationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs", ) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlImageClassificationMetadata', + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata", ) @@ -92,6 +92,7 @@ class AutoMlImageClassificationInputs(proto.Message): be trained (i.e. assuming that for each image multiple annotations may be applicable). """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -100,9 +101,7 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 3 MOBILE_TF_HIGH_ACCURACY_1 = 4 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) base_model_id = proto.Field(proto.STRING, number=2) @@ -127,6 +126,7 @@ class AutoMlImageClassificationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ + class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -135,8 +135,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field(proto.ENUM, number=2, - enum=SuccessfulStopReason, + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py index 9aa8ea5b3d..1c2c9f83b7 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", manifest={ - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", }, ) @@ -39,12 +39,12 @@ class AutoMlImageObjectDetection(proto.Message): The metadata information """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlImageObjectDetectionInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs", ) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlImageObjectDetectionMetadata', + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata", ) @@ -80,6 +80,7 @@ class AutoMlImageObjectDetectionInputs(proto.Message): training before the entire training budget has been used. """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -89,9 +90,7 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 4 MOBILE_TF_HIGH_ACCURACY_1 = 5 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -112,6 +111,7 @@ class AutoMlImageObjectDetectionMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ + class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -120,8 +120,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field(proto.ENUM, number=2, - enum=SuccessfulStopReason, + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py index 9188939a09..a81103657e 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", manifest={ - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", }, ) @@ -39,12 +39,12 @@ class AutoMlImageSegmentation(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlImageSegmentationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs", ) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlImageSegmentationMetadata', + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata", ) @@ -76,6 +76,7 @@ class AutoMlImageSegmentationInputs(proto.Message): ``base`` model must be in the same Project and Location as the new Model to train, and have the same modelType. """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -83,9 +84,7 @@ class ModelType(proto.Enum): CLOUD_LOW_ACCURACY_1 = 2 MOBILE_TF_LOW_LATENCY_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -106,6 +105,7 @@ class AutoMlImageSegmentationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ + class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -114,8 +114,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field(proto.ENUM, number=2, - enum=SuccessfulStopReason, + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py index 1efe804ca5..1c3d0c8da7 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py @@ -18,16 +18,14 @@ import proto # type: ignore -from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types import export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types import ( + export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config, +) __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",}, ) @@ -41,13 +39,9 @@ class AutoMlTables(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTablesInputs', - ) + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlTablesMetadata', - ) + metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",) class AutoMlTablesInputs(proto.Message): @@ -152,6 +146,7 @@ class AutoMlTablesInputs(proto.Message): configuration is absent, then the export is not performed. """ + class Transformation(proto.Message): r""" @@ -173,6 +168,7 @@ class Transformation(proto.Message): repeated_text (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation): """ + class AutoTransformation(proto.Message): r"""Training pipeline will infer the proper transformation based on the statistic of dataset. @@ -347,48 +343,76 @@ class TextArrayTransformation(proto.Message): column_name = proto.Field(proto.STRING, number=1) - auto = proto.Field(proto.MESSAGE, number=1, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.AutoTransformation', + auto = proto.Field( + proto.MESSAGE, + number=1, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.AutoTransformation", ) - numeric = proto.Field(proto.MESSAGE, number=2, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.NumericTransformation', + numeric = proto.Field( + proto.MESSAGE, + number=2, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.NumericTransformation", ) - categorical = proto.Field(proto.MESSAGE, number=3, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.CategoricalTransformation', + categorical = proto.Field( + proto.MESSAGE, + number=3, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.CategoricalTransformation", ) - timestamp = proto.Field(proto.MESSAGE, number=4, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.TimestampTransformation', + timestamp = proto.Field( + proto.MESSAGE, + number=4, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.TimestampTransformation", ) - text = proto.Field(proto.MESSAGE, number=5, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.TextTransformation', + text = proto.Field( + proto.MESSAGE, + number=5, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.TextTransformation", ) - repeated_numeric = proto.Field(proto.MESSAGE, number=6, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.NumericArrayTransformation', + repeated_numeric = proto.Field( + proto.MESSAGE, + number=6, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.NumericArrayTransformation", ) - repeated_categorical = proto.Field(proto.MESSAGE, number=7, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.CategoricalArrayTransformation', + repeated_categorical = proto.Field( + proto.MESSAGE, + number=7, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.CategoricalArrayTransformation", ) - repeated_text = proto.Field(proto.MESSAGE, number=8, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.TextArrayTransformation', + repeated_text = proto.Field( + proto.MESSAGE, + number=8, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.TextArrayTransformation", ) - optimization_objective_recall_value = proto.Field(proto.FLOAT, number=5, oneof='additional_optimization_objective_config') + optimization_objective_recall_value = proto.Field( + proto.FLOAT, number=5, oneof="additional_optimization_objective_config" + ) - optimization_objective_precision_value = proto.Field(proto.FLOAT, number=6, oneof='additional_optimization_objective_config') + optimization_objective_precision_value = proto.Field( + proto.FLOAT, number=6, oneof="additional_optimization_objective_config" + ) prediction_type = proto.Field(proto.STRING, number=1) target_column = proto.Field(proto.STRING, number=2) - transformations = proto.RepeatedField(proto.MESSAGE, number=3, - message=Transformation, + transformations = proto.RepeatedField( + proto.MESSAGE, number=3, message=Transformation, ) optimization_objective = proto.Field(proto.STRING, number=4) @@ -399,7 +423,9 @@ class TextArrayTransformation(proto.Message): weight_column_name = proto.Field(proto.STRING, number=9) - export_evaluated_data_items_config = proto.Field(proto.MESSAGE, number=10, + export_evaluated_data_items_config = proto.Field( + proto.MESSAGE, + number=10, message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig, ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py index adcd3a46fb..205deaf375 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlTextClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTextClassificationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs", ) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py index f6d6064504..fad28847af 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",}, ) @@ -36,9 +33,7 @@ class AutoMlTextExtraction(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTextExtractionInputs', - ) + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",) class AutoMlTextExtractionInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py index 5d67713e3d..ca80a44d1d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",}, ) @@ -36,9 +33,7 @@ class AutoMlTextSentiment(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTextSentimentInputs', - ) + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",) class AutoMlTextSentimentInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py index 06653758a7..1a20a6d725 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlVideoActionRecognition(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlVideoActionRecognitionInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs", ) @@ -48,15 +45,14 @@ class AutoMlVideoActionRecognitionInputs(proto.Message): model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoActionRecognitionInputs.ModelType): """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 CLOUD = 1 MOBILE_VERSATILE_1 = 2 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py index 486e4d0ecb..ba7f2d5b21 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlVideoClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlVideoClassificationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs", ) @@ -48,6 +45,7 @@ class AutoMlVideoClassificationInputs(proto.Message): model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoClassificationInputs.ModelType): """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -55,9 +53,7 @@ class ModelType(proto.Enum): MOBILE_VERSATILE_1 = 2 MOBILE_JETSON_VERSATILE_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py index de660f7d1d..0ecb1113d9 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlVideoObjectTracking(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlVideoObjectTrackingInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs", ) @@ -48,6 +45,7 @@ class AutoMlVideoObjectTrackingInputs(proto.Message): model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoObjectTrackingInputs.ModelType): """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -58,9 +56,7 @@ class ModelType(proto.Enum): MOBILE_JETSON_VERSATILE_1 = 5 MOBILE_JETSON_LOW_LATENCY_1 = 6 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py index a5b1fcb542..dc8a629412 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1.schema.trainingjob.definition', - manifest={ - 'ExportEvaluatedDataItemsConfig', - }, + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"ExportEvaluatedDataItemsConfig",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py index 62c5942a51..2f514ac4ed 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py @@ -15,24 +15,42 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_classification import ImageClassificationPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_object_detection import ImageObjectDetectionPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_segmentation import ImageSegmentationPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_classification import TextClassificationPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_extraction import TextExtractionPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_sentiment import TextSentimentPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_action_recognition import VideoActionRecognitionPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_classification import VideoClassificationPredictionInstance -from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_object_tracking import VideoObjectTrackingPredictionInstance +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_classification import ( + ImageClassificationPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_object_detection import ( + ImageObjectDetectionPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_segmentation import ( + ImageSegmentationPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_classification import ( + TextClassificationPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_extraction import ( + TextExtractionPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.text_sentiment import ( + TextSentimentPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_action_recognition import ( + VideoActionRecognitionPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_classification import ( + VideoClassificationPredictionInstance, +) +from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.video_object_tracking import ( + VideoObjectTrackingPredictionInstance, +) __all__ = ( - 'ImageClassificationPredictionInstance', - 'ImageObjectDetectionPredictionInstance', - 'ImageSegmentationPredictionInstance', - 'TextClassificationPredictionInstance', - 'TextExtractionPredictionInstance', - 'TextSentimentPredictionInstance', - 'VideoActionRecognitionPredictionInstance', - 'VideoClassificationPredictionInstance', - 'VideoObjectTrackingPredictionInstance', + "ImageClassificationPredictionInstance", + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py index c68b05e778..f6d9a128ad 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py @@ -27,13 +27,13 @@ __all__ = ( - 'ImageObjectDetectionPredictionInstance', - 'ImageSegmentationPredictionInstance', - 'TextClassificationPredictionInstance', - 'TextExtractionPredictionInstance', - 'TextSentimentPredictionInstance', - 'VideoActionRecognitionPredictionInstance', - 'VideoClassificationPredictionInstance', - 'VideoObjectTrackingPredictionInstance', -'ImageClassificationPredictionInstance', + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", + "ImageClassificationPredictionInstance", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py index aacf581e2e..041fe6cdb1 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py @@ -15,42 +15,24 @@ # limitations under the License. # -from .image_classification import ( - ImageClassificationPredictionInstance, -) -from .image_object_detection import ( - ImageObjectDetectionPredictionInstance, -) -from .image_segmentation import ( - ImageSegmentationPredictionInstance, -) -from .text_classification import ( - TextClassificationPredictionInstance, -) -from .text_extraction import ( - TextExtractionPredictionInstance, -) -from .text_sentiment import ( - TextSentimentPredictionInstance, -) -from .video_action_recognition import ( - VideoActionRecognitionPredictionInstance, -) -from .video_classification import ( - VideoClassificationPredictionInstance, -) -from .video_object_tracking import ( - VideoObjectTrackingPredictionInstance, -) +from .image_classification import ImageClassificationPredictionInstance +from .image_object_detection import ImageObjectDetectionPredictionInstance +from .image_segmentation import ImageSegmentationPredictionInstance +from .text_classification import TextClassificationPredictionInstance +from .text_extraction import TextExtractionPredictionInstance +from .text_sentiment import TextSentimentPredictionInstance +from .video_action_recognition import VideoActionRecognitionPredictionInstance +from .video_classification import VideoClassificationPredictionInstance +from .video_object_tracking import VideoObjectTrackingPredictionInstance __all__ = ( - 'ImageClassificationPredictionInstance', - 'ImageObjectDetectionPredictionInstance', - 'ImageSegmentationPredictionInstance', - 'TextClassificationPredictionInstance', - 'TextExtractionPredictionInstance', - 'TextSentimentPredictionInstance', - 'VideoActionRecognitionPredictionInstance', - 'VideoClassificationPredictionInstance', - 'VideoObjectTrackingPredictionInstance', + "ImageClassificationPredictionInstance", + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py index c0a0d477a4..84b1ef0bbe 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'ImageClassificationPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"ImageClassificationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py index 32cdc492ad..79c3efc2c6 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'ImageObjectDetectionPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"ImageObjectDetectionPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py index 0e1d5293ea..5a3232c6d2 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'ImageSegmentationPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"ImageSegmentationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py index 3ea5a96d5d..a615dc7e49 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'TextClassificationPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"TextClassificationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py index d256b7d008..c6fecf80b7 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'TextExtractionPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"TextExtractionPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py index 0e0a339a1c..69836d0e96 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'TextSentimentPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"TextSentimentPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py index 14a4e4ffec..ae3935d387 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'VideoActionRecognitionPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"VideoActionRecognitionPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py index 77e8d9e1c0..2f944bb99e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'VideoClassificationPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"VideoClassificationPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py index ab4b3f282f..e635b5174b 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.instance', - manifest={ - 'VideoObjectTrackingPredictionInstance', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.instance", + manifest={"VideoObjectTrackingPredictionInstance",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py index 0de177503e..dc7cd58e9a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py @@ -15,18 +15,30 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_classification import ImageClassificationPredictionParams -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_object_detection import ImageObjectDetectionPredictionParams -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_segmentation import ImageSegmentationPredictionParams -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_action_recognition import VideoActionRecognitionPredictionParams -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_classification import VideoClassificationPredictionParams -from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_object_tracking import VideoObjectTrackingPredictionParams +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_classification import ( + ImageClassificationPredictionParams, +) +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_object_detection import ( + ImageObjectDetectionPredictionParams, +) +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_segmentation import ( + ImageSegmentationPredictionParams, +) +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_action_recognition import ( + VideoActionRecognitionPredictionParams, +) +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_classification import ( + VideoClassificationPredictionParams, +) +from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.video_object_tracking import ( + VideoObjectTrackingPredictionParams, +) __all__ = ( - 'ImageClassificationPredictionParams', - 'ImageObjectDetectionPredictionParams', - 'ImageSegmentationPredictionParams', - 'VideoActionRecognitionPredictionParams', - 'VideoClassificationPredictionParams', - 'VideoObjectTrackingPredictionParams', + "ImageClassificationPredictionParams", + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py index 0e358981b3..79fb1c2097 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py @@ -24,10 +24,10 @@ __all__ = ( - 'ImageObjectDetectionPredictionParams', - 'ImageSegmentationPredictionParams', - 'VideoActionRecognitionPredictionParams', - 'VideoClassificationPredictionParams', - 'VideoObjectTrackingPredictionParams', -'ImageClassificationPredictionParams', + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", + "ImageClassificationPredictionParams", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py index 4f53fda062..2f2c29bba5 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py @@ -15,30 +15,18 @@ # limitations under the License. # -from .image_classification import ( - ImageClassificationPredictionParams, -) -from .image_object_detection import ( - ImageObjectDetectionPredictionParams, -) -from .image_segmentation import ( - ImageSegmentationPredictionParams, -) -from .video_action_recognition import ( - VideoActionRecognitionPredictionParams, -) -from .video_classification import ( - VideoClassificationPredictionParams, -) -from .video_object_tracking import ( - VideoObjectTrackingPredictionParams, -) +from .image_classification import ImageClassificationPredictionParams +from .image_object_detection import ImageObjectDetectionPredictionParams +from .image_segmentation import ImageSegmentationPredictionParams +from .video_action_recognition import VideoActionRecognitionPredictionParams +from .video_classification import VideoClassificationPredictionParams +from .video_object_tracking import VideoObjectTrackingPredictionParams __all__ = ( - 'ImageClassificationPredictionParams', - 'ImageObjectDetectionPredictionParams', - 'ImageSegmentationPredictionParams', - 'VideoActionRecognitionPredictionParams', - 'VideoClassificationPredictionParams', - 'VideoObjectTrackingPredictionParams', + "ImageClassificationPredictionParams", + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py index 1bfe57e1e6..681a8c3d87 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.params', - manifest={ - 'ImageClassificationPredictionParams', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.params", + manifest={"ImageClassificationPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py index ba86d17656..146dd324b7 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.params', - manifest={ - 'ImageObjectDetectionPredictionParams', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.params", + manifest={"ImageObjectDetectionPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py index ab5b028025..aa11739a61 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.params', - manifest={ - 'ImageSegmentationPredictionParams', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.params", + manifest={"ImageSegmentationPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py index 60b9bee8c8..c1f8f9f3bc 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.params', - manifest={ - 'VideoActionRecognitionPredictionParams', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.params", + manifest={"VideoActionRecognitionPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py index f90d338919..1b8d84a7d1 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.params', - manifest={ - 'VideoClassificationPredictionParams', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.params", + manifest={"VideoClassificationPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py index 7c92def8fc..4c0b6846bc 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.params', - manifest={ - 'VideoObjectTrackingPredictionParams', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.params", + manifest={"VideoObjectTrackingPredictionParams",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py index 5041ec8e6f..d5f2762504 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py @@ -15,26 +15,46 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.classification import ClassificationPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_object_detection import ImageObjectDetectionPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_segmentation import ImageSegmentationPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_classification import TabularClassificationPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_regression import TabularRegressionPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_extraction import TextExtractionPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_sentiment import TextSentimentPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_action_recognition import VideoActionRecognitionPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_classification import VideoClassificationPredictionResult -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_object_tracking import VideoObjectTrackingPredictionResult +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.classification import ( + ClassificationPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_object_detection import ( + ImageObjectDetectionPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.image_segmentation import ( + ImageSegmentationPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_classification import ( + TabularClassificationPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.tabular_regression import ( + TabularRegressionPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_extraction import ( + TextExtractionPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_sentiment import ( + TextSentimentPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_action_recognition import ( + VideoActionRecognitionPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_classification import ( + VideoClassificationPredictionResult, +) +from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_object_tracking import ( + VideoObjectTrackingPredictionResult, +) __all__ = ( - 'ClassificationPredictionResult', - 'ImageObjectDetectionPredictionResult', - 'ImageSegmentationPredictionResult', - 'TabularClassificationPredictionResult', - 'TabularRegressionPredictionResult', - 'TextExtractionPredictionResult', - 'TextSentimentPredictionResult', - 'VideoActionRecognitionPredictionResult', - 'VideoClassificationPredictionResult', - 'VideoObjectTrackingPredictionResult', + "ClassificationPredictionResult", + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py index 42f26f575f..91fae5a3b1 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py @@ -28,14 +28,14 @@ __all__ = ( - 'ImageObjectDetectionPredictionResult', - 'ImageSegmentationPredictionResult', - 'TabularClassificationPredictionResult', - 'TabularRegressionPredictionResult', - 'TextExtractionPredictionResult', - 'TextSentimentPredictionResult', - 'VideoActionRecognitionPredictionResult', - 'VideoClassificationPredictionResult', - 'VideoObjectTrackingPredictionResult', -'ClassificationPredictionResult', + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", + "ClassificationPredictionResult", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py index 019d5ea59c..a0fd2058e0 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py @@ -15,46 +15,26 @@ # limitations under the License. # -from .classification import ( - ClassificationPredictionResult, -) -from .image_object_detection import ( - ImageObjectDetectionPredictionResult, -) -from .image_segmentation import ( - ImageSegmentationPredictionResult, -) -from .tabular_classification import ( - TabularClassificationPredictionResult, -) -from .tabular_regression import ( - TabularRegressionPredictionResult, -) -from .text_extraction import ( - TextExtractionPredictionResult, -) -from .text_sentiment import ( - TextSentimentPredictionResult, -) -from .video_action_recognition import ( - VideoActionRecognitionPredictionResult, -) -from .video_classification import ( - VideoClassificationPredictionResult, -) -from .video_object_tracking import ( - VideoObjectTrackingPredictionResult, -) +from .classification import ClassificationPredictionResult +from .image_object_detection import ImageObjectDetectionPredictionResult +from .image_segmentation import ImageSegmentationPredictionResult +from .tabular_classification import TabularClassificationPredictionResult +from .tabular_regression import TabularRegressionPredictionResult +from .text_extraction import TextExtractionPredictionResult +from .text_sentiment import TextSentimentPredictionResult +from .video_action_recognition import VideoActionRecognitionPredictionResult +from .video_classification import VideoClassificationPredictionResult +from .video_object_tracking import VideoObjectTrackingPredictionResult __all__ = ( - 'ClassificationPredictionResult', - 'ImageObjectDetectionPredictionResult', - 'ImageSegmentationPredictionResult', - 'TabularClassificationPredictionResult', - 'TabularRegressionPredictionResult', - 'TextExtractionPredictionResult', - 'TextSentimentPredictionResult', - 'VideoActionRecognitionPredictionResult', - 'VideoClassificationPredictionResult', - 'VideoObjectTrackingPredictionResult', + "ClassificationPredictionResult", + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py index ed4bcece4f..3bfe82f64e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'ClassificationPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"ClassificationPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py index f125a9d4a6..3d0f7f1f76 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py @@ -22,10 +22,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'ImageObjectDetectionPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"ImageObjectDetectionPredictionResult",}, ) @@ -60,9 +58,7 @@ class ImageObjectDetectionPredictionResult(proto.Message): confidences = proto.RepeatedField(proto.FLOAT, number=3) - bboxes = proto.RepeatedField(proto.MESSAGE, number=4, - message=struct.ListValue, - ) + bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct.ListValue,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py index abc5977b79..ffd6fb9380 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'ImageSegmentationPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"ImageSegmentationPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py index bd373e8e8d..4906ad59a5 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'TabularClassificationPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"TabularClassificationPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py index bc21aaaf8d..71d535c1f0 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'TabularRegressionPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"TabularRegressionPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py index e23faf278f..e3c10b5d75 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'TextExtractionPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"TextExtractionPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py index 9a822e7782..f31b95a18f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'TextSentimentPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"TextSentimentPredictionResult",}, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py index 6b70a6c36c..99fa365b47 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py @@ -23,10 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'VideoActionRecognitionPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"VideoActionRecognitionPredictionResult",}, ) @@ -64,17 +62,13 @@ class VideoActionRecognitionPredictionResult(proto.Message): display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field(proto.MESSAGE, number=4, - message=duration.Duration, + time_segment_start = proto.Field( + proto.MESSAGE, number=4, message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, - message=duration.Duration, - ) + time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) - confidence = proto.Field(proto.MESSAGE, number=6, - message=wrappers.FloatValue, - ) + confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py index 2b435bbff8..3fca68fe64 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py @@ -23,10 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'VideoClassificationPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"VideoClassificationPredictionResult",}, ) @@ -80,17 +78,13 @@ class VideoClassificationPredictionResult(proto.Message): type_ = proto.Field(proto.STRING, number=3) - time_segment_start = proto.Field(proto.MESSAGE, number=4, - message=duration.Duration, + time_segment_start = proto.Field( + proto.MESSAGE, number=4, message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=5, - message=duration.Duration, - ) + time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,) - confidence = proto.Field(proto.MESSAGE, number=6, - message=wrappers.FloatValue, - ) + confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py index 2bbf98710c..6fd431c0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py @@ -23,10 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.predict.prediction', - manifest={ - 'VideoObjectTrackingPredictionResult', - }, + package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", + manifest={"VideoObjectTrackingPredictionResult",}, ) @@ -64,6 +62,7 @@ class VideoObjectTrackingPredictionResult(proto.Message): bounding boxes in the frames identify the same object. """ + class Frame(proto.Message): r"""The fields ``xMin``, ``xMax``, ``yMin``, and ``yMax`` refer to a bounding box, i.e. the rectangle over the video frame pinpointing @@ -88,45 +87,29 @@ class Frame(proto.Message): box. """ - time_offset = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + time_offset = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) - x_min = proto.Field(proto.MESSAGE, number=2, - message=wrappers.FloatValue, - ) + x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers.FloatValue,) - x_max = proto.Field(proto.MESSAGE, number=3, - message=wrappers.FloatValue, - ) + x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers.FloatValue,) - y_min = proto.Field(proto.MESSAGE, number=4, - message=wrappers.FloatValue, - ) + y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers.FloatValue,) - y_max = proto.Field(proto.MESSAGE, number=5, - message=wrappers.FloatValue, - ) + y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) id = proto.Field(proto.STRING, number=1) display_name = proto.Field(proto.STRING, number=2) - time_segment_start = proto.Field(proto.MESSAGE, number=3, - message=duration.Duration, + time_segment_start = proto.Field( + proto.MESSAGE, number=3, message=duration.Duration, ) - time_segment_end = proto.Field(proto.MESSAGE, number=4, - message=duration.Duration, - ) + time_segment_end = proto.Field(proto.MESSAGE, number=4, message=duration.Duration,) - confidence = proto.Field(proto.MESSAGE, number=5, - message=wrappers.FloatValue, - ) + confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) - frames = proto.RepeatedField(proto.MESSAGE, number=6, - message=Frame, - ) + frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py index 9475d2c67c..d632ef9609 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py @@ -15,56 +15,106 @@ # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import AutoMlImageClassification -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import AutoMlImageClassificationInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import AutoMlImageClassificationMetadata -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import AutoMlImageObjectDetection -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import AutoMlImageObjectDetectionInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import AutoMlImageObjectDetectionMetadata -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import AutoMlImageSegmentation -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import AutoMlImageSegmentationInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import AutoMlImageSegmentationMetadata -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import AutoMlTables -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import AutoMlTablesInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import AutoMlTablesMetadata -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import AutoMlTextClassification -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import AutoMlTextClassificationInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import AutoMlTextExtraction -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import AutoMlTextExtractionInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import AutoMlTextSentiment -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import AutoMlTextSentimentInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import AutoMlVideoActionRecognition -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import AutoMlVideoActionRecognitionInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import AutoMlVideoClassification -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import AutoMlVideoClassificationInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import AutoMlVideoObjectTracking -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import AutoMlVideoObjectTrackingInputs -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( + AutoMlImageClassification, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( + AutoMlImageClassificationInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( + AutoMlImageClassificationMetadata, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import ( + AutoMlImageObjectDetection, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import ( + AutoMlImageObjectDetectionInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_object_detection import ( + AutoMlImageObjectDetectionMetadata, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import ( + AutoMlImageSegmentation, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import ( + AutoMlImageSegmentationInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_segmentation import ( + AutoMlImageSegmentationMetadata, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import ( + AutoMlTables, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import ( + AutoMlTablesInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_tables import ( + AutoMlTablesMetadata, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import ( + AutoMlTextClassification, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_classification import ( + AutoMlTextClassificationInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import ( + AutoMlTextExtraction, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_extraction import ( + AutoMlTextExtractionInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import ( + AutoMlTextSentiment, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import ( + AutoMlTextSentimentInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import ( + AutoMlVideoActionRecognition, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import ( + AutoMlVideoActionRecognitionInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import ( + AutoMlVideoClassification, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_classification import ( + AutoMlVideoClassificationInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import ( + AutoMlVideoObjectTracking, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_object_tracking import ( + AutoMlVideoObjectTrackingInputs, +) +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.export_evaluated_data_items_config import ( + ExportEvaluatedDataItemsConfig, +) __all__ = ( - 'AutoMlImageClassification', - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - 'ExportEvaluatedDataItemsConfig', + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py index 135e04f228..34958e5add 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py @@ -43,29 +43,29 @@ __all__ = ( - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - 'ExportEvaluatedDataItemsConfig', -'AutoMlImageClassification', + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", + "AutoMlImageClassification", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py index 2d7d19c057..a15aa2c041 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py @@ -59,34 +59,32 @@ AutoMlVideoObjectTracking, AutoMlVideoObjectTrackingInputs, ) -from .export_evaluated_data_items_config import ( - ExportEvaluatedDataItemsConfig, -) +from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig __all__ = ( - 'AutoMlImageClassification', - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - 'ExportEvaluatedDataItemsConfig', + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py index 6eb4ada23e..8ee27076d2 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", manifest={ - 'AutoMlImageClassification', - 'AutoMlImageClassificationInputs', - 'AutoMlImageClassificationMetadata', + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", }, ) @@ -39,12 +39,12 @@ class AutoMlImageClassification(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlImageClassificationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs", ) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlImageClassificationMetadata', + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata", ) @@ -92,6 +92,7 @@ class AutoMlImageClassificationInputs(proto.Message): be trained (i.e. assuming that for each image multiple annotations may be applicable). """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -100,9 +101,7 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 3 MOBILE_TF_HIGH_ACCURACY_1 = 4 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) base_model_id = proto.Field(proto.STRING, number=2) @@ -127,6 +126,7 @@ class AutoMlImageClassificationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ + class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -135,8 +135,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field(proto.ENUM, number=2, - enum=SuccessfulStopReason, + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py index 6cd9a9684d..512e35ed1d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", manifest={ - 'AutoMlImageObjectDetection', - 'AutoMlImageObjectDetectionInputs', - 'AutoMlImageObjectDetectionMetadata', + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", }, ) @@ -39,12 +39,12 @@ class AutoMlImageObjectDetection(proto.Message): The metadata information """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlImageObjectDetectionInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs", ) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlImageObjectDetectionMetadata', + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata", ) @@ -80,6 +80,7 @@ class AutoMlImageObjectDetectionInputs(proto.Message): training before the entire training budget has been used. """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -89,9 +90,7 @@ class ModelType(proto.Enum): MOBILE_TF_VERSATILE_1 = 4 MOBILE_TF_HIGH_ACCURACY_1 = 5 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -112,6 +111,7 @@ class AutoMlImageObjectDetectionMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ + class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -120,8 +120,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field(proto.ENUM, number=2, - enum=SuccessfulStopReason, + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py index 28fd9d385d..014df43b2f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py @@ -19,11 +19,11 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", manifest={ - 'AutoMlImageSegmentation', - 'AutoMlImageSegmentationInputs', - 'AutoMlImageSegmentationMetadata', + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", }, ) @@ -39,12 +39,12 @@ class AutoMlImageSegmentation(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlImageSegmentationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs", ) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlImageSegmentationMetadata', + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata", ) @@ -76,6 +76,7 @@ class AutoMlImageSegmentationInputs(proto.Message): ``base`` model must be in the same Project and Location as the new Model to train, and have the same modelType. """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -83,9 +84,7 @@ class ModelType(proto.Enum): CLOUD_LOW_ACCURACY_1 = 2 MOBILE_TF_LOW_LATENCY_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) budget_milli_node_hours = proto.Field(proto.INT64, number=2) @@ -106,6 +105,7 @@ class AutoMlImageSegmentationMetadata(proto.Message): For successful job completions, this is the reason why the job has finished. """ + class SuccessfulStopReason(proto.Enum): r"""""" SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 @@ -114,8 +114,8 @@ class SuccessfulStopReason(proto.Enum): cost_milli_node_hours = proto.Field(proto.INT64, number=1) - successful_stop_reason = proto.Field(proto.ENUM, number=2, - enum=SuccessfulStopReason, + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py index a506fe6493..19c43929e8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py @@ -18,16 +18,14 @@ import proto # type: ignore -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config +from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import ( + export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config, +) __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlTables', - 'AutoMlTablesInputs', - 'AutoMlTablesMetadata', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",}, ) @@ -41,13 +39,9 @@ class AutoMlTables(proto.Message): The metadata information. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTablesInputs', - ) + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",) - metadata = proto.Field(proto.MESSAGE, number=2, - message='AutoMlTablesMetadata', - ) + metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",) class AutoMlTablesInputs(proto.Message): @@ -152,6 +146,7 @@ class AutoMlTablesInputs(proto.Message): configuration is absent, then the export is not performed. """ + class Transformation(proto.Message): r""" @@ -173,6 +168,7 @@ class Transformation(proto.Message): repeated_text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation): """ + class AutoTransformation(proto.Message): r"""Training pipeline will infer the proper transformation based on the statistic of dataset. @@ -347,48 +343,76 @@ class TextArrayTransformation(proto.Message): column_name = proto.Field(proto.STRING, number=1) - auto = proto.Field(proto.MESSAGE, number=1, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.AutoTransformation', + auto = proto.Field( + proto.MESSAGE, + number=1, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.AutoTransformation", ) - numeric = proto.Field(proto.MESSAGE, number=2, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.NumericTransformation', + numeric = proto.Field( + proto.MESSAGE, + number=2, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.NumericTransformation", ) - categorical = proto.Field(proto.MESSAGE, number=3, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.CategoricalTransformation', + categorical = proto.Field( + proto.MESSAGE, + number=3, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.CategoricalTransformation", ) - timestamp = proto.Field(proto.MESSAGE, number=4, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.TimestampTransformation', + timestamp = proto.Field( + proto.MESSAGE, + number=4, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.TimestampTransformation", ) - text = proto.Field(proto.MESSAGE, number=5, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.TextTransformation', + text = proto.Field( + proto.MESSAGE, + number=5, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.TextTransformation", ) - repeated_numeric = proto.Field(proto.MESSAGE, number=6, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.NumericArrayTransformation', + repeated_numeric = proto.Field( + proto.MESSAGE, + number=6, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.NumericArrayTransformation", ) - repeated_categorical = proto.Field(proto.MESSAGE, number=7, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.CategoricalArrayTransformation', + repeated_categorical = proto.Field( + proto.MESSAGE, + number=7, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.CategoricalArrayTransformation", ) - repeated_text = proto.Field(proto.MESSAGE, number=8, oneof='transformation_detail', - message='AutoMlTablesInputs.Transformation.TextArrayTransformation', + repeated_text = proto.Field( + proto.MESSAGE, + number=8, + oneof="transformation_detail", + message="AutoMlTablesInputs.Transformation.TextArrayTransformation", ) - optimization_objective_recall_value = proto.Field(proto.FLOAT, number=5, oneof='additional_optimization_objective_config') + optimization_objective_recall_value = proto.Field( + proto.FLOAT, number=5, oneof="additional_optimization_objective_config" + ) - optimization_objective_precision_value = proto.Field(proto.FLOAT, number=6, oneof='additional_optimization_objective_config') + optimization_objective_precision_value = proto.Field( + proto.FLOAT, number=6, oneof="additional_optimization_objective_config" + ) prediction_type = proto.Field(proto.STRING, number=1) target_column = proto.Field(proto.STRING, number=2) - transformations = proto.RepeatedField(proto.MESSAGE, number=3, - message=Transformation, + transformations = proto.RepeatedField( + proto.MESSAGE, number=3, message=Transformation, ) optimization_objective = proto.Field(proto.STRING, number=4) @@ -399,7 +423,9 @@ class TextArrayTransformation(proto.Message): weight_column_name = proto.Field(proto.STRING, number=9) - export_evaluated_data_items_config = proto.Field(proto.MESSAGE, number=10, + export_evaluated_data_items_config = proto.Field( + proto.MESSAGE, + number=10, message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py index dd9c448258..9fe6b865c9 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlTextClassification', - 'AutoMlTextClassificationInputs', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlTextClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTextClassificationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py index d1111f379f..c7b1fc6dba 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlTextExtraction', - 'AutoMlTextExtractionInputs', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",}, ) @@ -36,9 +33,7 @@ class AutoMlTextExtraction(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTextExtractionInputs', - ) + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",) class AutoMlTextExtractionInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py index 06f4fa06f9..8239b55fdf 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlTextSentiment', - 'AutoMlTextSentimentInputs', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",}, ) @@ -36,9 +33,7 @@ class AutoMlTextSentiment(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlTextSentimentInputs', - ) + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",) class AutoMlTextSentimentInputs(proto.Message): diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py index e795fa10c5..66448faf01 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlVideoActionRecognition', - 'AutoMlVideoActionRecognitionInputs', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlVideoActionRecognition(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlVideoActionRecognitionInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs", ) @@ -48,15 +45,14 @@ class AutoMlVideoActionRecognitionInputs(proto.Message): model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoActionRecognitionInputs.ModelType): """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 CLOUD = 1 MOBILE_VERSATILE_1 = 2 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py index 2d3ffbf007..e1c12eb46c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlVideoClassification', - 'AutoMlVideoClassificationInputs', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlVideoClassification(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlVideoClassificationInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs", ) @@ -48,6 +45,7 @@ class AutoMlVideoClassificationInputs(proto.Message): model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoClassificationInputs.ModelType): """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -55,9 +53,7 @@ class ModelType(proto.Enum): MOBILE_VERSATILE_1 = 2 MOBILE_JETSON_VERSATILE_1 = 3 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py index adf69eee56..328e266a3b 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'AutoMlVideoObjectTracking', - 'AutoMlVideoObjectTrackingInputs', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",}, ) @@ -36,8 +33,8 @@ class AutoMlVideoObjectTracking(proto.Message): The input parameters of this TrainingJob. """ - inputs = proto.Field(proto.MESSAGE, number=1, - message='AutoMlVideoObjectTrackingInputs', + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs", ) @@ -48,6 +45,7 @@ class AutoMlVideoObjectTrackingInputs(proto.Message): model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoObjectTrackingInputs.ModelType): """ + class ModelType(proto.Enum): r"""""" MODEL_TYPE_UNSPECIFIED = 0 @@ -58,9 +56,7 @@ class ModelType(proto.Enum): MOBILE_JETSON_VERSATILE_1 = 5 MOBILE_JETSON_LOW_LATENCY_1 = 6 - model_type = proto.Field(proto.ENUM, number=1, - enum=ModelType, - ) + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py index 2770d78441..9a6195fec2 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1.schema.trainingjob.definition', - manifest={ - 'ExportEvaluatedDataItemsConfig', - }, + package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + manifest={"ExportEvaluatedDataItemsConfig",}, ) diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index 24c5acb6bb..1b0c76e834 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -180,166 +180,166 @@ __all__ = ( - 'AcceleratorType', - 'ActiveLearningConfig', - 'Annotation', - 'AnnotationSpec', - 'AutomaticResources', - 'BatchDedicatedResources', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'BatchPredictionJob', - 'BigQueryDestination', - 'BigQuerySource', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CancelTrainingPipelineRequest', - 'CompletionStats', - 'ContainerRegistryDestination', - 'ContainerSpec', - 'CreateBatchPredictionJobRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'CreateHyperparameterTuningJobRequest', - 'CreateSpecialistPoolOperationMetadata', - 'CreateSpecialistPoolRequest', - 'CreateTrainingPipelineRequest', - 'CustomJob', - 'CustomJobSpec', - 'DataItem', - 'DataLabelingJob', - 'Dataset', - 'DatasetServiceClient', - 'DedicatedResources', - 'DeleteBatchPredictionJobRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteDatasetRequest', - 'DeleteEndpointRequest', - 'DeleteHyperparameterTuningJobRequest', - 'DeleteModelRequest', - 'DeleteOperationMetadata', - 'DeleteSpecialistPoolRequest', - 'DeleteTrainingPipelineRequest', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployedModel', - 'DeployedModelRef', - 'DiskSpec', - 'EncryptionSpec', - 'Endpoint', - 'EndpointServiceClient', - 'EnvVar', - 'ExportDataConfig', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'FilterSplit', - 'FractionSplit', - 'GcsDestination', - 'GcsSource', - 'GenericOperationMetadata', - 'GetAnnotationSpecRequest', - 'GetBatchPredictionJobRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetDatasetRequest', - 'GetEndpointRequest', - 'GetHyperparameterTuningJobRequest', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'GetSpecialistPoolRequest', - 'GetTrainingPipelineRequest', - 'HyperparameterTuningJob', - 'ImportDataConfig', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'InputDataConfig', - 'JobServiceClient', - 'JobState', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'MachineSpec', - 'ManualBatchTuningParameters', - 'Measurement', - 'MigratableResource', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'MigrationServiceClient', - 'Model', - 'ModelContainerSpec', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'ModelServiceClient', - 'PipelineServiceClient', - 'PipelineState', - 'Port', - 'PredefinedSplit', - 'PredictRequest', - 'PredictResponse', - 'PredictSchemata', - 'PredictionServiceClient', - 'PythonPackageSpec', - 'ResourcesConsumed', - 'SampleConfig', - 'Scheduling', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'SpecialistPool', - 'StudySpec', - 'TimestampSplit', - 'TrainingConfig', - 'TrainingPipeline', - 'Trial', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateDatasetRequest', - 'UpdateEndpointRequest', - 'UpdateModelRequest', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 'UploadModelResponse', - 'UserActionReference', - 'WorkerPoolSpec', -'SpecialistPoolServiceClient', + "AcceleratorType", + "ActiveLearningConfig", + "Annotation", + "AnnotationSpec", + "AutomaticResources", + "BatchDedicatedResources", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "BatchPredictionJob", + "BigQueryDestination", + "BigQuerySource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CancelTrainingPipelineRequest", + "CompletionStats", + "ContainerRegistryDestination", + "ContainerSpec", + "CreateBatchPredictionJobRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "CreateHyperparameterTuningJobRequest", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "CreateTrainingPipelineRequest", + "CustomJob", + "CustomJobSpec", + "DataItem", + "DataLabelingJob", + "Dataset", + "DatasetServiceClient", + "DedicatedResources", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteDatasetRequest", + "DeleteEndpointRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteModelRequest", + "DeleteOperationMetadata", + "DeleteSpecialistPoolRequest", + "DeleteTrainingPipelineRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "DeployedModel", + "DeployedModelRef", + "DiskSpec", + "EncryptionSpec", + "Endpoint", + "EndpointServiceClient", + "EnvVar", + "ExportDataConfig", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "FilterSplit", + "FractionSplit", + "GcsDestination", + "GcsSource", + "GenericOperationMetadata", + "GetAnnotationSpecRequest", + "GetBatchPredictionJobRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetDatasetRequest", + "GetEndpointRequest", + "GetHyperparameterTuningJobRequest", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "GetSpecialistPoolRequest", + "GetTrainingPipelineRequest", + "HyperparameterTuningJob", + "ImportDataConfig", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "InputDataConfig", + "JobServiceClient", + "JobState", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "ListEndpointsRequest", + "ListEndpointsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "MachineSpec", + "ManualBatchTuningParameters", + "Measurement", + "MigratableResource", + "MigrateResourceRequest", + "MigrateResourceResponse", + "MigrationServiceClient", + "Model", + "ModelContainerSpec", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelServiceClient", + "PipelineServiceClient", + "PipelineState", + "Port", + "PredefinedSplit", + "PredictRequest", + "PredictResponse", + "PredictSchemata", + "PredictionServiceClient", + "PythonPackageSpec", + "ResourcesConsumed", + "SampleConfig", + "Scheduling", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "SpecialistPool", + "StudySpec", + "TimestampSplit", + "TrainingConfig", + "TrainingPipeline", + "Trial", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateDatasetRequest", + "UpdateEndpointRequest", + "UpdateModelRequest", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "UserActionReference", + "WorkerPoolSpec", + "SpecialistPoolServiceClient", ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py index 9d1f004f6a..597f654cb9 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import DatasetServiceAsyncClient __all__ = ( - 'DatasetServiceClient', - 'DatasetServiceAsyncClient', + "DatasetServiceClient", + "DatasetServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index 950d920c5a..a07ee32dfd 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,26 +60,42 @@ class DatasetServiceAsyncClient: annotation_path = staticmethod(DatasetServiceClient.annotation_path) parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) - parse_annotation_spec_path = staticmethod(DatasetServiceClient.parse_annotation_spec_path) + parse_annotation_spec_path = staticmethod( + DatasetServiceClient.parse_annotation_spec_path + ) data_item_path = staticmethod(DatasetServiceClient.data_item_path) parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) dataset_path = staticmethod(DatasetServiceClient.dataset_path) parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) - common_billing_account_path = staticmethod(DatasetServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(DatasetServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + DatasetServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + DatasetServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(DatasetServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + DatasetServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(DatasetServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(DatasetServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + DatasetServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DatasetServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(DatasetServiceClient.common_project_path) - parse_common_project_path = staticmethod(DatasetServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + DatasetServiceClient.parse_common_project_path + ) common_location_path = staticmethod(DatasetServiceClient.common_location_path) - parse_common_location_path = staticmethod(DatasetServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + DatasetServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -122,14 +138,18 @@ def transport(self) -> DatasetServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient)) + get_transport_class = functools.partial( + type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -168,18 +188,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a Dataset. Args: @@ -220,8 +240,10 @@ async def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.CreateDatasetRequest(request) @@ -244,18 +266,11 @@ async def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -268,14 +283,15 @@ async def create_dataset(self, # Done; return the response. return response - async def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + async def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -307,8 +323,10 @@ async def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.GetDatasetRequest(request) @@ -329,31 +347,25 @@ async def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + async def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -398,8 +410,10 @@ async def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.UpdateDatasetRequest(request) @@ -422,30 +436,26 @@ async def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsAsyncPager: + async def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: r"""Lists Datasets in a Location. Args: @@ -480,8 +490,10 @@ async def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListDatasetsRequest(request) @@ -502,39 +514,30 @@ async def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDatasetsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Dataset. Args: @@ -580,8 +583,10 @@ async def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.DeleteDatasetRequest(request) @@ -602,18 +607,11 @@ async def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -626,15 +624,16 @@ async def delete_dataset(self, # Done; return the response. return response - async def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Imports data into a Dataset. Args: @@ -678,8 +677,10 @@ async def import_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ImportDataRequest(request) @@ -703,18 +704,11 @@ async def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -727,15 +721,16 @@ async def import_data(self, # Done; return the response. return response - async def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports data from a Dataset. Args: @@ -778,8 +773,10 @@ async def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ExportDataRequest(request) @@ -802,18 +799,11 @@ async def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -826,14 +816,15 @@ async def export_data(self, # Done; return the response. return response - async def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsAsyncPager: + async def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: r"""Lists DataItems in a Dataset. Args: @@ -869,8 +860,10 @@ async def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListDataItemsRequest(request) @@ -891,39 +884,30 @@ async def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataItemsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + async def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -957,8 +941,10 @@ async def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.GetAnnotationSpecRequest(request) @@ -979,30 +965,24 @@ async def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsAsyncPager: + async def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1039,8 +1019,10 @@ async def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListAnnotationsRequest(request) @@ -1061,47 +1043,30 @@ async def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListAnnotationsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceAsyncClient', -) +__all__ = ("DatasetServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 52109ac90b..160a2049b8 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,14 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry['grpc'] = DatasetServiceGrpcTransport - _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[DatasetServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +118,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +153,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DatasetServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,110 +169,149 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str,str]: + def parse_annotation_path(path: str) -> Dict[str, str]: """Parse a annotation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str,str]: + def parse_annotation_spec_path(path: str) -> Dict[str, str]: """Parse a annotation_spec path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str,str]: + def parse_data_item_path(path: str) -> Dict[str, str]: """Parse a data_item path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -316,7 +355,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -326,7 +367,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -338,7 +381,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -350,8 +395,10 @@ def __init__(self, *, if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -370,15 +417,16 @@ def __init__(self, *, client_info=client_info, ) - def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a Dataset. Args: @@ -419,8 +467,10 @@ def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -444,18 +494,11 @@ def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -468,14 +511,15 @@ def create_dataset(self, # Done; return the response. return response - def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -507,8 +551,10 @@ def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -530,31 +576,25 @@ def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -599,8 +639,10 @@ def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -624,30 +666,26 @@ def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -682,8 +720,10 @@ def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -705,39 +745,30 @@ def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Dataset. Args: @@ -783,8 +814,10 @@ def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -806,18 +839,11 @@ def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -830,15 +856,16 @@ def delete_dataset(self, # Done; return the response. return response - def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Imports data into a Dataset. Args: @@ -882,8 +909,10 @@ def import_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -907,18 +936,11 @@ def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -931,15 +953,16 @@ def import_data(self, # Done; return the response. return response - def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports data from a Dataset. Args: @@ -982,8 +1005,10 @@ def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -1007,18 +1032,11 @@ def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1031,14 +1049,15 @@ def export_data(self, # Done; return the response. return response - def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -1074,8 +1093,10 @@ def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1097,39 +1118,30 @@ def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1163,8 +1175,10 @@ def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1186,30 +1200,24 @@ def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1246,8 +1254,10 @@ def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1269,47 +1279,30 @@ def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceClient', -) +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py index 3439dc331c..c3f8265b6e 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import annotation from google.cloud.aiplatform_v1.types import data_item @@ -40,12 +49,15 @@ class ListDatasetsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListDatasetsResponse], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -79,7 +91,7 @@ def __iter__(self) -> Iterable[dataset.Dataset]: yield from page.datasets def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDatasetsAsyncPager: @@ -99,12 +111,15 @@ class ListDatasetsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -142,7 +157,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataItemsPager: @@ -162,12 +177,15 @@ class ListDataItemsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListDataItemsResponse], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +219,7 @@ def __iter__(self) -> Iterable[data_item.DataItem]: yield from page.data_items def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataItemsAsyncPager: @@ -221,12 +239,15 @@ class ListDataItemsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -264,7 +285,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListAnnotationsPager: @@ -284,12 +305,15 @@ class ListAnnotationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListAnnotationsResponse], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +347,7 @@ def __iter__(self) -> Iterable[annotation.Annotation]: yield from page.annotations def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListAnnotationsAsyncPager: @@ -343,12 +367,15 @@ class ListAnnotationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -386,4 +413,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py index 5f02a0f0d9..a4461d2ced 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] -_transport_registry['grpc'] = DatasetServiceGrpcTransport -_transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = DatasetServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport __all__ = ( - 'DatasetServiceTransport', - 'DatasetServiceGrpcTransport', - 'DatasetServiceGrpcAsyncIOTransport', + "DatasetServiceTransport", + "DatasetServiceGrpcTransport", + "DatasetServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py index 15daeb6369..bf2165e7af 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,8 +81,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -91,17 +91,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -110,56 +112,35 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, - default_timeout=None, - client_info=client_info, + self.create_dataset, default_timeout=None, client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, - default_timeout=None, - client_info=client_info, + self.get_dataset, default_timeout=None, client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, - default_timeout=None, - client_info=client_info, + self.update_dataset, default_timeout=None, client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, - default_timeout=None, - client_info=client_info, + self.list_datasets, default_timeout=None, client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, - default_timeout=None, - client_info=client_info, + self.delete_dataset, default_timeout=None, client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( - self.import_data, - default_timeout=None, - client_info=client_info, + self.import_data, default_timeout=None, client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( - self.export_data, - default_timeout=None, - client_info=client_info, + self.export_data, default_timeout=None, client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, - default_timeout=None, - client_info=client_info, + self.list_data_items, default_timeout=None, client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, - default_timeout=None, - client_info=client_info, + self.get_annotation_spec, default_timeout=None, client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, - default_timeout=None, - client_info=client_info, + self.list_annotations, default_timeout=None, client_info=client_info, ), - } @property @@ -168,96 +149,106 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_dataset(self) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_dataset( + self, + ) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_dataset(self) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[ - dataset.Dataset, - typing.Awaitable[dataset.Dataset] - ]]: + def get_dataset( + self, + ) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], + ]: raise NotImplementedError() @property - def update_dataset(self) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[ - gca_dataset.Dataset, - typing.Awaitable[gca_dataset.Dataset] - ]]: + def update_dataset( + self, + ) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], + ]: raise NotImplementedError() @property - def list_datasets(self) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse] - ]]: + def list_datasets( + self, + ) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse], + ], + ]: raise NotImplementedError() @property - def delete_dataset(self) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_dataset( + self, + ) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def import_data(self) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def import_data( + self, + ) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_data(self) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_data( + self, + ) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def list_data_items(self) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse] - ]]: + def list_data_items( + self, + ) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse], + ], + ]: raise NotImplementedError() @property - def get_annotation_spec(self) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec] - ]]: + def get_annotation_spec( + self, + ) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec], + ], + ]: raise NotImplementedError() @property - def list_annotations(self) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse] - ]]: + def list_annotations( + self, + ) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'DatasetServiceTransport', -) +__all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py index 96efd8e427..65bd8baf79 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -46,21 +46,24 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -172,13 +175,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -211,7 +216,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -229,17 +234,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_dataset(self) -> Callable[ - [dataset_service.CreateDatasetRequest], - operations.Operation]: + def create_dataset( + self, + ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -254,18 +257,18 @@ def create_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_dataset' not in self._stubs: - self._stubs['create_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/CreateDataset', + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/CreateDataset", request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_dataset'] + return self._stubs["create_dataset"] @property - def get_dataset(self) -> Callable[ - [dataset_service.GetDatasetRequest], - dataset.Dataset]: + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -280,18 +283,18 @@ def get_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_dataset' not in self._stubs: - self._stubs['get_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/GetDataset', + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/GetDataset", request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs['get_dataset'] + return self._stubs["get_dataset"] @property - def update_dataset(self) -> Callable[ - [dataset_service.UpdateDatasetRequest], - gca_dataset.Dataset]: + def update_dataset( + self, + ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -306,18 +309,20 @@ def update_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_dataset' not in self._stubs: - self._stubs['update_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/UpdateDataset', + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/UpdateDataset", request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs['update_dataset'] + return self._stubs["update_dataset"] @property - def list_datasets(self) -> Callable[ - [dataset_service.ListDatasetsRequest], - dataset_service.ListDatasetsResponse]: + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse + ]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -332,18 +337,18 @@ def list_datasets(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_datasets' not in self._stubs: - self._stubs['list_datasets'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ListDatasets', + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDatasets", request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs['list_datasets'] + return self._stubs["list_datasets"] @property - def delete_dataset(self) -> Callable[ - [dataset_service.DeleteDatasetRequest], - operations.Operation]: + def delete_dataset( + self, + ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -358,18 +363,18 @@ def delete_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_dataset' not in self._stubs: - self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/DeleteDataset', + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/DeleteDataset", request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_dataset'] + return self._stubs["delete_dataset"] @property - def import_data(self) -> Callable[ - [dataset_service.ImportDataRequest], - operations.Operation]: + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -384,18 +389,18 @@ def import_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_data' not in self._stubs: - self._stubs['import_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ImportData', + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ImportData", request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_data'] + return self._stubs["import_data"] @property - def export_data(self) -> Callable[ - [dataset_service.ExportDataRequest], - operations.Operation]: + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -410,18 +415,20 @@ def export_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_data' not in self._stubs: - self._stubs['export_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ExportData', + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ExportData", request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_data'] + return self._stubs["export_data"] @property - def list_data_items(self) -> Callable[ - [dataset_service.ListDataItemsRequest], - dataset_service.ListDataItemsResponse]: + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse + ]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -436,18 +443,20 @@ def list_data_items(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_items' not in self._stubs: - self._stubs['list_data_items'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ListDataItems', + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDataItems", request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs['list_data_items'] + return self._stubs["list_data_items"] @property - def get_annotation_spec(self) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - annotation_spec.AnnotationSpec]: + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec + ]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -462,18 +471,21 @@ def get_annotation_spec(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_annotation_spec' not in self._stubs: - self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec', + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec", request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs['get_annotation_spec'] + return self._stubs["get_annotation_spec"] @property - def list_annotations(self) -> Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse]: + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse, + ]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -488,15 +500,13 @@ def list_annotations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_annotations' not in self._stubs: - self._stubs['list_annotations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ListAnnotations', + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListAnnotations", request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs['list_annotations'] + return self._stubs["list_annotations"] -__all__ = ( - 'DatasetServiceGrpcTransport', -) +__all__ = ("DatasetServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py index 924299a2f7..90d4dc67f2 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import annotation_spec @@ -53,13 +53,15 @@ class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -88,22 +90,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -242,9 +246,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_dataset(self) -> Callable[ - [dataset_service.CreateDatasetRequest], - Awaitable[operations.Operation]]: + def create_dataset( + self, + ) -> Callable[ + [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -259,18 +265,18 @@ def create_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_dataset' not in self._stubs: - self._stubs['create_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/CreateDataset', + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/CreateDataset", request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_dataset'] + return self._stubs["create_dataset"] @property - def get_dataset(self) -> Callable[ - [dataset_service.GetDatasetRequest], - Awaitable[dataset.Dataset]]: + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -285,18 +291,20 @@ def get_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_dataset' not in self._stubs: - self._stubs['get_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/GetDataset', + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/GetDataset", request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs['get_dataset'] + return self._stubs["get_dataset"] @property - def update_dataset(self) -> Callable[ - [dataset_service.UpdateDatasetRequest], - Awaitable[gca_dataset.Dataset]]: + def update_dataset( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] + ]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -311,18 +319,21 @@ def update_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_dataset' not in self._stubs: - self._stubs['update_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/UpdateDataset', + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/UpdateDataset", request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs['update_dataset'] + return self._stubs["update_dataset"] @property - def list_datasets(self) -> Callable[ - [dataset_service.ListDatasetsRequest], - Awaitable[dataset_service.ListDatasetsResponse]]: + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse], + ]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -337,18 +348,20 @@ def list_datasets(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_datasets' not in self._stubs: - self._stubs['list_datasets'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ListDatasets', + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDatasets", request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs['list_datasets'] + return self._stubs["list_datasets"] @property - def delete_dataset(self) -> Callable[ - [dataset_service.DeleteDatasetRequest], - Awaitable[operations.Operation]]: + def delete_dataset( + self, + ) -> Callable[ + [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -363,18 +376,18 @@ def delete_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_dataset' not in self._stubs: - self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/DeleteDataset', + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/DeleteDataset", request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_dataset'] + return self._stubs["delete_dataset"] @property - def import_data(self) -> Callable[ - [dataset_service.ImportDataRequest], - Awaitable[operations.Operation]]: + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -389,18 +402,18 @@ def import_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_data' not in self._stubs: - self._stubs['import_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ImportData', + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ImportData", request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_data'] + return self._stubs["import_data"] @property - def export_data(self) -> Callable[ - [dataset_service.ExportDataRequest], - Awaitable[operations.Operation]]: + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -415,18 +428,21 @@ def export_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_data' not in self._stubs: - self._stubs['export_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ExportData', + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ExportData", request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_data'] + return self._stubs["export_data"] @property - def list_data_items(self) -> Callable[ - [dataset_service.ListDataItemsRequest], - Awaitable[dataset_service.ListDataItemsResponse]]: + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse], + ]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -441,18 +457,21 @@ def list_data_items(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_items' not in self._stubs: - self._stubs['list_data_items'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ListDataItems', + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDataItems", request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs['list_data_items'] + return self._stubs["list_data_items"] @property - def get_annotation_spec(self) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - Awaitable[annotation_spec.AnnotationSpec]]: + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + Awaitable[annotation_spec.AnnotationSpec], + ]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -467,18 +486,21 @@ def get_annotation_spec(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_annotation_spec' not in self._stubs: - self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec', + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec", request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs['get_annotation_spec'] + return self._stubs["get_annotation_spec"] @property - def list_annotations(self) -> Callable[ - [dataset_service.ListAnnotationsRequest], - Awaitable[dataset_service.ListAnnotationsResponse]]: + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + Awaitable[dataset_service.ListAnnotationsResponse], + ]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -493,15 +515,13 @@ def list_annotations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_annotations' not in self._stubs: - self._stubs['list_annotations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.DatasetService/ListAnnotations', + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListAnnotations", request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs['list_annotations'] + return self._stubs["list_annotations"] -__all__ = ( - 'DatasetServiceGrpcAsyncIOTransport', -) +__all__ = ("DatasetServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py index e4f3dcfbcf..035a5b2388 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import EndpointServiceAsyncClient __all__ = ( - 'EndpointServiceClient', - 'EndpointServiceAsyncClient', + "EndpointServiceClient", + "EndpointServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index 244c35bcba..13f099328b 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -58,20 +58,34 @@ class EndpointServiceAsyncClient: model_path = staticmethod(EndpointServiceClient.model_path) parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) - common_billing_account_path = staticmethod(EndpointServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(EndpointServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + EndpointServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + EndpointServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(EndpointServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + EndpointServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(EndpointServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(EndpointServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + EndpointServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + EndpointServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(EndpointServiceClient.common_project_path) - parse_common_project_path = staticmethod(EndpointServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + EndpointServiceClient.parse_common_project_path + ) common_location_path = staticmethod(EndpointServiceClient.common_location_path) - parse_common_location_path = staticmethod(EndpointServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + EndpointServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -114,14 +128,18 @@ def transport(self) -> EndpointServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient)) + get_transport_class = functools.partial( + type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, EndpointServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -160,18 +178,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an Endpoint. Args: @@ -211,8 +229,10 @@ async def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.CreateEndpointRequest(request) @@ -235,18 +255,11 @@ async def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -259,14 +272,15 @@ async def create_endpoint(self, # Done; return the response. return response - async def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + async def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -299,8 +313,10 @@ async def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.GetEndpointRequest(request) @@ -321,30 +337,24 @@ async def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsAsyncPager: + async def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: r"""Lists Endpoints in a Location. Args: @@ -380,8 +390,10 @@ async def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.ListEndpointsRequest(request) @@ -402,40 +414,31 @@ async def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListEndpointsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + async def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -475,8 +478,10 @@ async def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.UpdateEndpointRequest(request) @@ -499,30 +504,26 @@ async def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an Endpoint. Args: @@ -568,8 +569,10 @@ async def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.DeleteEndpointRequest(request) @@ -590,18 +593,11 @@ async def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -614,16 +610,19 @@ async def delete_endpoint(self, # Done; return the response. return response - async def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -692,8 +691,10 @@ async def deploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.DeployModelRequest(request) @@ -719,18 +720,11 @@ async def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -743,16 +737,19 @@ async def deploy_model(self, # Done; return the response. return response - async def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -812,8 +809,10 @@ async def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.UndeployModelRequest(request) @@ -839,18 +838,11 @@ async def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -864,21 +856,14 @@ async def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceAsyncClient', -) +__all__ = ("EndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 3b78f5902e..de54b0b9b5 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -56,13 +56,14 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry['grpc'] = EndpointServiceGrpcTransport - _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[EndpointServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry["grpc"] = EndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -113,7 +114,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -148,9 +149,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: EndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,88 +165,104 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -290,7 +306,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -300,7 +318,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -312,7 +332,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -324,8 +346,10 @@ def __init__(self, *, if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -344,15 +368,16 @@ def __init__(self, *, client_info=client_info, ) - def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates an Endpoint. Args: @@ -392,8 +417,10 @@ def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -417,18 +444,11 @@ def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -441,14 +461,15 @@ def create_endpoint(self, # Done; return the response. return response - def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -481,8 +502,10 @@ def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -504,30 +527,24 @@ def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -563,8 +580,10 @@ def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -586,40 +605,31 @@ def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -659,8 +669,10 @@ def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -684,30 +696,26 @@ def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes an Endpoint. Args: @@ -753,8 +761,10 @@ def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -776,18 +786,11 @@ def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -800,16 +803,19 @@ def delete_endpoint(self, # Done; return the response. return response - def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -878,8 +884,10 @@ def deploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -905,18 +913,11 @@ def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -929,16 +930,19 @@ def deploy_model(self, # Done; return the response. return response - def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -998,8 +1002,10 @@ def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -1025,18 +1031,11 @@ def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1050,21 +1049,14 @@ def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceClient', -) +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py index 154c455826..c22df91c8c 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import endpoint from google.cloud.aiplatform_v1.types import endpoint_service @@ -38,12 +47,15 @@ class ListEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., endpoint_service.ListEndpointsResponse], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: yield from page.endpoints def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListEndpointsAsyncPager: @@ -97,12 +109,15 @@ class ListEndpointsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -140,4 +155,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py index eb2ef767fe..3d0695461d 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] -_transport_registry['grpc'] = EndpointServiceGrpcTransport -_transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = EndpointServiceGrpcTransport +_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport __all__ = ( - 'EndpointServiceTransport', - 'EndpointServiceGrpcTransport', - 'EndpointServiceGrpcAsyncIOTransport', + "EndpointServiceTransport", + "EndpointServiceGrpcTransport", + "EndpointServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index 43520356ad..8dfc7282c3 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -80,8 +80,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -90,17 +90,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -109,41 +111,26 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, - default_timeout=None, - client_info=client_info, + self.create_endpoint, default_timeout=None, client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, - default_timeout=None, - client_info=client_info, + self.get_endpoint, default_timeout=None, client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, - default_timeout=None, - client_info=client_info, + self.list_endpoints, default_timeout=None, client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, - default_timeout=None, - client_info=client_info, + self.update_endpoint, default_timeout=None, client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, - default_timeout=None, - client_info=client_info, + self.delete_endpoint, default_timeout=None, client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, - default_timeout=None, - client_info=client_info, + self.deploy_model, default_timeout=None, client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, - default_timeout=None, - client_info=client_info, + self.undeploy_model, default_timeout=None, client_info=client_info, ), - } @property @@ -152,69 +139,70 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_endpoint(self) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_endpoint(self) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[ - endpoint.Endpoint, - typing.Awaitable[endpoint.Endpoint] - ]]: + def get_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def list_endpoints(self) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse] - ]]: + def list_endpoints( + self, + ) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse], + ], + ]: raise NotImplementedError() @property - def update_endpoint(self) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[ - gca_endpoint.Endpoint, - typing.Awaitable[gca_endpoint.Endpoint] - ]]: + def update_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def delete_endpoint(self) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def deploy_model(self) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def deploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def undeploy_model(self) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def undeploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'EndpointServiceTransport', -) +__all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index 448aa173b9..8a2c837161 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -45,21 +45,24 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -171,13 +174,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -210,7 +215,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -228,17 +233,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_endpoint(self) -> Callable[ - [endpoint_service.CreateEndpointRequest], - operations.Operation]: + def create_endpoint( + self, + ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -253,18 +256,18 @@ def create_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_endpoint' not in self._stubs: - self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint', + if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint", request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_endpoint'] + return self._stubs["create_endpoint"] @property - def get_endpoint(self) -> Callable[ - [endpoint_service.GetEndpointRequest], - endpoint.Endpoint]: + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -279,18 +282,20 @@ def get_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_endpoint' not in self._stubs: - self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/GetEndpoint', + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/GetEndpoint", request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs['get_endpoint'] + return self._stubs["get_endpoint"] @property - def list_endpoints(self) -> Callable[ - [endpoint_service.ListEndpointsRequest], - endpoint_service.ListEndpointsResponse]: + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse + ]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -305,18 +310,18 @@ def list_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_endpoints' not in self._stubs: - self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/ListEndpoints', + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/ListEndpoints", request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs['list_endpoints'] + return self._stubs["list_endpoints"] @property - def update_endpoint(self) -> Callable[ - [endpoint_service.UpdateEndpointRequest], - gca_endpoint.Endpoint]: + def update_endpoint( + self, + ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -331,18 +336,18 @@ def update_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_endpoint' not in self._stubs: - self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint', + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint", request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs['update_endpoint'] + return self._stubs["update_endpoint"] @property - def delete_endpoint(self) -> Callable[ - [endpoint_service.DeleteEndpointRequest], - operations.Operation]: + def delete_endpoint( + self, + ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -357,18 +362,18 @@ def delete_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_endpoint' not in self._stubs: - self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint', + if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint", request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_endpoint'] + return self._stubs["delete_endpoint"] @property - def deploy_model(self) -> Callable[ - [endpoint_service.DeployModelRequest], - operations.Operation]: + def deploy_model( + self, + ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -384,18 +389,18 @@ def deploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'deploy_model' not in self._stubs: - self._stubs['deploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/DeployModel', + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeployModel", request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['deploy_model'] + return self._stubs["deploy_model"] @property - def undeploy_model(self) -> Callable[ - [endpoint_service.UndeployModelRequest], - operations.Operation]: + def undeploy_model( + self, + ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -412,15 +417,13 @@ def undeploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'undeploy_model' not in self._stubs: - self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/UndeployModel', + if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UndeployModel", request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['undeploy_model'] + return self._stubs["undeploy_model"] -__all__ = ( - 'EndpointServiceGrpcTransport', -) +__all__ = ("EndpointServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py index 14e2735edd..d10160a493 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import endpoint @@ -52,13 +52,15 @@ class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -87,22 +89,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -241,9 +245,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_endpoint(self) -> Callable[ - [endpoint_service.CreateEndpointRequest], - Awaitable[operations.Operation]]: + def create_endpoint( + self, + ) -> Callable[ + [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -258,18 +264,18 @@ def create_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_endpoint' not in self._stubs: - self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint', + if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint", request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_endpoint'] + return self._stubs["create_endpoint"] @property - def get_endpoint(self) -> Callable[ - [endpoint_service.GetEndpointRequest], - Awaitable[endpoint.Endpoint]]: + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -284,18 +290,21 @@ def get_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_endpoint' not in self._stubs: - self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/GetEndpoint', + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/GetEndpoint", request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs['get_endpoint'] + return self._stubs["get_endpoint"] @property - def list_endpoints(self) -> Callable[ - [endpoint_service.ListEndpointsRequest], - Awaitable[endpoint_service.ListEndpointsResponse]]: + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse], + ]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -310,18 +319,20 @@ def list_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_endpoints' not in self._stubs: - self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/ListEndpoints', + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/ListEndpoints", request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs['list_endpoints'] + return self._stubs["list_endpoints"] @property - def update_endpoint(self) -> Callable[ - [endpoint_service.UpdateEndpointRequest], - Awaitable[gca_endpoint.Endpoint]]: + def update_endpoint( + self, + ) -> Callable[ + [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] + ]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -336,18 +347,20 @@ def update_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_endpoint' not in self._stubs: - self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint', + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint", request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs['update_endpoint'] + return self._stubs["update_endpoint"] @property - def delete_endpoint(self) -> Callable[ - [endpoint_service.DeleteEndpointRequest], - Awaitable[operations.Operation]]: + def delete_endpoint( + self, + ) -> Callable[ + [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -362,18 +375,20 @@ def delete_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_endpoint' not in self._stubs: - self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint', + if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint", request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_endpoint'] + return self._stubs["delete_endpoint"] @property - def deploy_model(self) -> Callable[ - [endpoint_service.DeployModelRequest], - Awaitable[operations.Operation]]: + def deploy_model( + self, + ) -> Callable[ + [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -389,18 +404,20 @@ def deploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'deploy_model' not in self._stubs: - self._stubs['deploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/DeployModel', + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeployModel", request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['deploy_model'] + return self._stubs["deploy_model"] @property - def undeploy_model(self) -> Callable[ - [endpoint_service.UndeployModelRequest], - Awaitable[operations.Operation]]: + def undeploy_model( + self, + ) -> Callable[ + [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -417,15 +434,13 @@ def undeploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'undeploy_model' not in self._stubs: - self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.EndpointService/UndeployModel', + if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UndeployModel", request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['undeploy_model'] + return self._stubs["undeploy_model"] -__all__ = ( - 'EndpointServiceGrpcAsyncIOTransport', -) +__all__ = ("EndpointServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/job_service/__init__.py b/google/cloud/aiplatform_v1/services/job_service/__init__.py index 037407b714..5f157047f5 100644 --- a/google/cloud/aiplatform_v1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/job_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import JobServiceAsyncClient __all__ = ( - 'JobServiceClient', - 'JobServiceAsyncClient', + "JobServiceClient", + "JobServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index e76498a85d..e253bcc5d6 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -21,18 +21,20 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -40,7 +42,9 @@ from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources @@ -67,34 +71,50 @@ class JobServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) - parse_batch_prediction_job_path = staticmethod(JobServiceClient.parse_batch_prediction_job_path) + parse_batch_prediction_job_path = staticmethod( + JobServiceClient.parse_batch_prediction_job_path + ) custom_job_path = staticmethod(JobServiceClient.custom_job_path) parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) - parse_data_labeling_job_path = staticmethod(JobServiceClient.parse_data_labeling_job_path) + parse_data_labeling_job_path = staticmethod( + JobServiceClient.parse_data_labeling_job_path + ) dataset_path = staticmethod(JobServiceClient.dataset_path) parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) - hyperparameter_tuning_job_path = staticmethod(JobServiceClient.hyperparameter_tuning_job_path) - parse_hyperparameter_tuning_job_path = staticmethod(JobServiceClient.parse_hyperparameter_tuning_job_path) + hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.hyperparameter_tuning_job_path + ) + parse_hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.parse_hyperparameter_tuning_job_path + ) model_path = staticmethod(JobServiceClient.model_path) parse_model_path = staticmethod(JobServiceClient.parse_model_path) trial_path = staticmethod(JobServiceClient.trial_path) parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) - common_billing_account_path = staticmethod(JobServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(JobServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + JobServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + JobServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(JobServiceClient.common_folder_path) parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) common_organization_path = staticmethod(JobServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(JobServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + JobServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(JobServiceClient.common_project_path) parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) common_location_path = staticmethod(JobServiceClient.common_location_path) - parse_common_location_path = staticmethod(JobServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + JobServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -137,14 +157,18 @@ def transport(self) -> JobServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(JobServiceClient).get_transport_class, type(JobServiceClient)) + get_transport_class = functools.partial( + type(JobServiceClient).get_transport_class, type(JobServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, JobServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -183,18 +207,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + async def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -239,8 +263,10 @@ async def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateCustomJobRequest(request) @@ -263,30 +289,24 @@ async def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + async def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -324,8 +344,10 @@ async def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetCustomJobRequest(request) @@ -346,30 +368,24 @@ async def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsAsyncPager: + async def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: r"""Lists CustomJobs in a Location. Args: @@ -405,8 +421,10 @@ async def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListCustomJobsRequest(request) @@ -427,39 +445,30 @@ async def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListCustomJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a CustomJob. Args: @@ -505,8 +514,10 @@ async def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteCustomJobRequest(request) @@ -527,18 +538,11 @@ async def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -551,14 +555,15 @@ async def delete_custom_job(self, # Done; return the response. return response - async def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -596,8 +601,10 @@ async def cancel_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelCustomJobRequest(request) @@ -618,28 +625,24 @@ async def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -679,8 +682,10 @@ async def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateDataLabelingJobRequest(request) @@ -703,30 +708,24 @@ async def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + async def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -760,8 +759,10 @@ async def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetDataLabelingJobRequest(request) @@ -782,30 +783,24 @@ async def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsAsyncPager: + async def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -840,8 +835,10 @@ async def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListDataLabelingJobsRequest(request) @@ -862,39 +859,30 @@ async def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataLabelingJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a DataLabelingJob. Args: @@ -941,8 +929,10 @@ async def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteDataLabelingJobRequest(request) @@ -963,18 +953,11 @@ async def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -987,14 +970,15 @@ async def delete_data_labeling_job(self, # Done; return the response. return response - async def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1022,8 +1006,10 @@ async def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelDataLabelingJobRequest(request) @@ -1044,28 +1030,24 @@ async def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1107,8 +1089,10 @@ async def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateHyperparameterTuningJobRequest(request) @@ -1131,30 +1115,24 @@ async def create_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + async def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1190,8 +1168,10 @@ async def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetHyperparameterTuningJobRequest(request) @@ -1212,30 +1192,24 @@ async def get_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + async def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1271,8 +1245,10 @@ async def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListHyperparameterTuningJobsRequest(request) @@ -1293,39 +1269,30 @@ async def list_hyperparameter_tuning_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListHyperparameterTuningJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1372,8 +1339,10 @@ async def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteHyperparameterTuningJobRequest(request) @@ -1394,18 +1363,11 @@ async def delete_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1418,14 +1380,15 @@ async def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - async def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1466,8 +1429,10 @@ async def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelHyperparameterTuningJobRequest(request) @@ -1488,28 +1453,24 @@ async def cancel_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1554,8 +1515,10 @@ async def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateBatchPredictionJobRequest(request) @@ -1578,30 +1541,24 @@ async def create_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + async def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1639,8 +1596,10 @@ async def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetBatchPredictionJobRequest(request) @@ -1661,30 +1620,24 @@ async def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsAsyncPager: + async def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsAsyncPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1720,8 +1673,10 @@ async def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListBatchPredictionJobsRequest(request) @@ -1742,39 +1697,30 @@ async def list_batch_prediction_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListBatchPredictionJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -1822,8 +1768,10 @@ async def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteBatchPredictionJobRequest(request) @@ -1844,18 +1792,11 @@ async def delete_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1868,14 +1809,15 @@ async def delete_batch_prediction_job(self, # Done; return the response. return response - async def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -1914,8 +1856,10 @@ async def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelBatchPredictionJobRequest(request) @@ -1936,35 +1880,23 @@ async def cancel_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceAsyncClient', -) +__all__ = ("JobServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index 1a304de108..746ce91c4b 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -23,20 +23,22 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -44,7 +46,9 @@ from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources @@ -69,13 +73,12 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry['grpc'] = JobServiceGrpcTransport - _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[JobServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -126,7 +129,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -161,9 +164,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: JobServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -178,143 +180,194 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: + def batch_prediction_job_path( + project: str, location: str, batch_prediction_job: str, + ) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, + location=location, + batch_prediction_job=batch_prediction_job, + ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str,location: str,custom_job: str,) -> str: + def custom_job_path(project: str, location: str, custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str,str]: + def parse_custom_job_path(path: str) -> Dict[str, str]: """Parse a custom_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str,str]: + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: + def hyperparameter_tuning_job_path( + project: str, location: str, hyperparameter_tuning_job: str, + ) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str,location: str,study: str,trial: str,) -> str: + def trial_path(project: str, location: str, study: str, trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) @staticmethod - def parse_trial_path(path: str) -> Dict[str,str]: + def parse_trial_path(path: str) -> Dict[str, str]: """Parse a trial path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -358,7 +411,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -368,7 +423,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -380,7 +437,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -392,8 +451,10 @@ def __init__(self, *, if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -412,15 +473,16 @@ def __init__(self, *, client_info=client_info, ) - def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -465,8 +527,10 @@ def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -490,30 +554,24 @@ def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -551,8 +609,10 @@ def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -574,30 +634,24 @@ def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -633,8 +687,10 @@ def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -656,39 +712,30 @@ def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a CustomJob. Args: @@ -734,8 +781,10 @@ def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -757,18 +806,11 @@ def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -781,14 +823,15 @@ def delete_custom_job(self, # Done; return the response. return response - def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -826,8 +869,10 @@ def cancel_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -849,28 +894,24 @@ def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -910,8 +951,10 @@ def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -935,30 +978,24 @@ def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -992,8 +1029,10 @@ def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -1015,30 +1054,24 @@ def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1073,8 +1106,10 @@ def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1096,39 +1131,30 @@ def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1175,8 +1201,10 @@ def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1198,18 +1226,11 @@ def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1222,14 +1243,15 @@ def delete_data_labeling_job(self, # Done; return the response. return response - def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1257,8 +1279,10 @@ def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1280,28 +1304,24 @@ def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1343,8 +1363,10 @@ def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1363,35 +1385,31 @@ def create_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1427,8 +1445,10 @@ def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1445,35 +1465,31 @@ def get_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1509,8 +1525,10 @@ def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1527,44 +1545,37 @@ def list_hyperparameter_tuning_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1611,8 +1622,10 @@ def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1629,23 +1642,18 @@ def delete_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1658,14 +1666,15 @@ def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1706,8 +1715,10 @@ def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1724,33 +1735,31 @@ def cancel_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1795,8 +1804,10 @@ def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1815,35 +1826,31 @@ def create_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1881,8 +1888,10 @@ def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1904,30 +1913,24 @@ def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1963,8 +1966,10 @@ def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -1981,44 +1986,37 @@ def list_batch_prediction_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2066,8 +2064,10 @@ def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2084,23 +2084,18 @@ def delete_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2113,14 +2108,15 @@ def delete_batch_prediction_job(self, # Done; return the response. return response - def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -2159,8 +2155,10 @@ def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2177,40 +2175,30 @@ def cancel_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceClient', -) +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/job_service/pagers.py b/google/cloud/aiplatform_v1/services/job_service/pagers.py index dfc5e30105..35d679b6ad 100644 --- a/google/cloud/aiplatform_v1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/job_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job @@ -41,12 +50,15 @@ class ListCustomJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListCustomJobsResponse], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -80,7 +92,7 @@ def __iter__(self) -> Iterable[custom_job.CustomJob]: yield from page.custom_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListCustomJobsAsyncPager: @@ -100,12 +112,15 @@ class ListCustomJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -143,7 +158,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataLabelingJobsPager: @@ -163,12 +178,15 @@ class ListDataLabelingJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListDataLabelingJobsResponse], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -202,7 +220,7 @@ def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: yield from page.data_labeling_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataLabelingJobsAsyncPager: @@ -222,12 +240,15 @@ class ListDataLabelingJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -265,7 +286,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsPager: @@ -285,12 +306,15 @@ class ListHyperparameterTuningJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -324,7 +348,7 @@ def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob yield from page.hyperparameter_tuning_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsAsyncPager: @@ -344,12 +368,17 @@ class ListHyperparameterTuningJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListHyperparameterTuningJobsResponse]], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -371,14 +400,18 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + async def pages( + self, + ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__(self) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + def __aiter__( + self, + ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: async def async_generator(): async for page in self.pages: for response in page.hyperparameter_tuning_jobs: @@ -387,7 +420,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListBatchPredictionJobsPager: @@ -407,12 +440,15 @@ class ListBatchPredictionJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListBatchPredictionJobsResponse], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -446,7 +482,7 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: yield from page.batch_prediction_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListBatchPredictionJobsAsyncPager: @@ -466,12 +502,15 @@ class ListBatchPredictionJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -509,4 +548,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py index 8b5de46a7e..349bfbcdea 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] -_transport_registry['grpc'] = JobServiceGrpcTransport -_transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = JobServiceGrpcTransport +_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport __all__ = ( - 'JobServiceTransport', - 'JobServiceGrpcTransport', - 'JobServiceGrpcAsyncIOTransport', + "JobServiceTransport", + "JobServiceGrpcTransport", + "JobServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py index f3ee6dc74a..7151ba3611 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -21,19 +21,23 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -42,29 +46,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -87,8 +91,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -97,17 +101,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -116,29 +122,19 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, - default_timeout=None, - client_info=client_info, + self.create_custom_job, default_timeout=None, client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, - default_timeout=None, - client_info=client_info, + self.get_custom_job, default_timeout=None, client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, - default_timeout=None, - client_info=client_info, + self.list_custom_jobs, default_timeout=None, client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, - default_timeout=None, - client_info=client_info, + self.delete_custom_job, default_timeout=None, client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, - default_timeout=None, - client_info=client_info, + self.cancel_custom_job, default_timeout=None, client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, @@ -215,7 +211,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -224,186 +219,216 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_custom_job(self) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, - typing.Awaitable[gca_custom_job.CustomJob] - ]]: + def create_custom_job( + self, + ) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] + ], + ]: raise NotImplementedError() @property - def get_custom_job(self) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[ - custom_job.CustomJob, - typing.Awaitable[custom_job.CustomJob] - ]]: + def get_custom_job( + self, + ) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], + ]: raise NotImplementedError() @property - def list_custom_jobs(self) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse] - ]]: + def list_custom_jobs( + self, + ) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_custom_job(self) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_custom_job( + self, + ) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_custom_job(self) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_custom_job( + self, + ) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_data_labeling_job(self) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob] - ]]: + def create_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def get_data_labeling_job(self) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob] - ]]: + def get_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def list_data_labeling_jobs(self) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse] - ]]: + def list_data_labeling_jobs( + self, + ) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_data_labeling_job(self) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_data_labeling_job(self) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def create_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def get_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def get_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_batch_prediction_job(self) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] - ]]: + def create_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def get_batch_prediction_job(self) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - typing.Union[ - batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob] - ]]: + def get_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def list_batch_prediction_jobs(self) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse] - ]]: + def list_batch_prediction_jobs( + self, + ) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_batch_prediction_job(self) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_batch_prediction_job(self) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'JobServiceTransport', -) +__all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py index 9a88545dd8..f01883829b 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -18,23 +18,27 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -54,21 +58,24 @@ class JobServiceGrpcTransport(JobServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -180,13 +187,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -219,7 +228,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -237,17 +246,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_custom_job(self) -> Callable[ - [job_service.CreateCustomJobRequest], - gca_custom_job.CustomJob]: + def create_custom_job( + self, + ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -263,18 +270,18 @@ def create_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_custom_job' not in self._stubs: - self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateCustomJob', + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateCustomJob", request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs['create_custom_job'] + return self._stubs["create_custom_job"] @property - def get_custom_job(self) -> Callable[ - [job_service.GetCustomJobRequest], - custom_job.CustomJob]: + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -289,18 +296,20 @@ def get_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_custom_job' not in self._stubs: - self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetCustomJob', + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetCustomJob", request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs['get_custom_job'] + return self._stubs["get_custom_job"] @property - def list_custom_jobs(self) -> Callable[ - [job_service.ListCustomJobsRequest], - job_service.ListCustomJobsResponse]: + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse + ]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -315,18 +324,18 @@ def list_custom_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_custom_jobs' not in self._stubs: - self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListCustomJobs', + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListCustomJobs", request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs['list_custom_jobs'] + return self._stubs["list_custom_jobs"] @property - def delete_custom_job(self) -> Callable[ - [job_service.DeleteCustomJobRequest], - operations.Operation]: + def delete_custom_job( + self, + ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -341,18 +350,18 @@ def delete_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_custom_job' not in self._stubs: - self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteCustomJob', + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteCustomJob", request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_custom_job'] + return self._stubs["delete_custom_job"] @property - def cancel_custom_job(self) -> Callable[ - [job_service.CancelCustomJobRequest], - empty.Empty]: + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -379,18 +388,21 @@ def cancel_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_custom_job' not in self._stubs: - self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelCustomJob', + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelCustomJob", request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_custom_job'] + return self._stubs["cancel_custom_job"] @property - def create_data_labeling_job(self) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob]: + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob, + ]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -405,18 +417,20 @@ def create_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_data_labeling_job' not in self._stubs: - self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob', + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob", request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['create_data_labeling_job'] + return self._stubs["create_data_labeling_job"] @property - def get_data_labeling_job(self) -> Callable[ - [job_service.GetDataLabelingJobRequest], - data_labeling_job.DataLabelingJob]: + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob + ]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -431,18 +445,21 @@ def get_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_data_labeling_job' not in self._stubs: - self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob', + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob", request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['get_data_labeling_job'] + return self._stubs["get_data_labeling_job"] @property - def list_data_labeling_jobs(self) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse]: + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse, + ]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -457,18 +474,18 @@ def list_data_labeling_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_labeling_jobs' not in self._stubs: - self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs', + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs", request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs['list_data_labeling_jobs'] + return self._stubs["list_data_labeling_jobs"] @property - def delete_data_labeling_job(self) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], - operations.Operation]: + def delete_data_labeling_job( + self, + ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -483,18 +500,18 @@ def delete_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_data_labeling_job' not in self._stubs: - self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob', + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob", request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_data_labeling_job'] + return self._stubs["delete_data_labeling_job"] @property - def cancel_data_labeling_job(self) -> Callable[ - [job_service.CancelDataLabelingJobRequest], - empty.Empty]: + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -510,18 +527,21 @@ def cancel_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_data_labeling_job' not in self._stubs: - self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob', + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob", request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_data_labeling_job'] + return self._stubs["cancel_data_labeling_job"] @property - def create_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob]: + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + ]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -537,18 +557,23 @@ def create_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_hyperparameter_tuning_job' not in self._stubs: - self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob', + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob", request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['create_hyperparameter_tuning_job'] + return self._stubs["create_hyperparameter_tuning_job"] @property - def get_hyperparameter_tuning_job(self) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob]: + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob, + ]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -563,18 +588,23 @@ def get_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_hyperparameter_tuning_job' not in self._stubs: - self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob', + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob", request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['get_hyperparameter_tuning_job'] + return self._stubs["get_hyperparameter_tuning_job"] @property - def list_hyperparameter_tuning_jobs(self) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse]: + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse, + ]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -590,18 +620,22 @@ def list_hyperparameter_tuning_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_hyperparameter_tuning_jobs' not in self._stubs: - self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs', + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs", request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs['list_hyperparameter_tuning_jobs'] + return self._stubs["list_hyperparameter_tuning_jobs"] @property - def delete_hyperparameter_tuning_job(self) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - operations.Operation]: + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation + ]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -617,18 +651,20 @@ def delete_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_hyperparameter_tuning_job' not in self._stubs: - self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob', + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob", request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_hyperparameter_tuning_job'] + return self._stubs["delete_hyperparameter_tuning_job"] @property - def cancel_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - empty.Empty]: + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -657,18 +693,23 @@ def cancel_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_hyperparameter_tuning_job' not in self._stubs: - self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob', + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob", request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_hyperparameter_tuning_job'] + return self._stubs["cancel_hyperparameter_tuning_job"] @property - def create_batch_prediction_job(self) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob]: + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + gca_batch_prediction_job.BatchPredictionJob, + ]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -684,18 +725,21 @@ def create_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_batch_prediction_job' not in self._stubs: - self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob', + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob", request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['create_batch_prediction_job'] + return self._stubs["create_batch_prediction_job"] @property - def get_batch_prediction_job(self) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob]: + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob, + ]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -710,18 +754,21 @@ def get_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_batch_prediction_job' not in self._stubs: - self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob', + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob", request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['get_batch_prediction_job'] + return self._stubs["get_batch_prediction_job"] @property - def list_batch_prediction_jobs(self) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse]: + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse, + ]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -736,18 +783,18 @@ def list_batch_prediction_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_batch_prediction_jobs' not in self._stubs: - self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs', + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs", request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs['list_batch_prediction_jobs'] + return self._stubs["list_batch_prediction_jobs"] @property - def delete_batch_prediction_job(self) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], - operations.Operation]: + def delete_batch_prediction_job( + self, + ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -763,18 +810,18 @@ def delete_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_batch_prediction_job' not in self._stubs: - self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob', + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob", request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_batch_prediction_job'] + return self._stubs["delete_batch_prediction_job"] @property - def cancel_batch_prediction_job(self) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], - empty.Empty]: + def cancel_batch_prediction_job( + self, + ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -800,15 +847,13 @@ def cancel_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_batch_prediction_job' not in self._stubs: - self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob', + if "cancel_batch_prediction_job" not in self._stubs: + self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob", request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_batch_prediction_job'] + return self._stubs["cancel_batch_prediction_job"] -__all__ = ( - 'JobServiceGrpcTransport', -) +__all__ = ("JobServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py index 2ce9fb52e0..c12b584256 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py @@ -18,24 +18,28 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -61,13 +65,15 @@ class JobServiceGrpcAsyncIOTransport(JobServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -96,22 +102,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -250,9 +258,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_custom_job(self) -> Callable[ - [job_service.CreateCustomJobRequest], - Awaitable[gca_custom_job.CustomJob]]: + def create_custom_job( + self, + ) -> Callable[ + [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] + ]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -268,18 +278,18 @@ def create_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_custom_job' not in self._stubs: - self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateCustomJob', + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateCustomJob", request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs['create_custom_job'] + return self._stubs["create_custom_job"] @property - def get_custom_job(self) -> Callable[ - [job_service.GetCustomJobRequest], - Awaitable[custom_job.CustomJob]]: + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -294,18 +304,21 @@ def get_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_custom_job' not in self._stubs: - self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetCustomJob', + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetCustomJob", request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs['get_custom_job'] + return self._stubs["get_custom_job"] @property - def list_custom_jobs(self) -> Callable[ - [job_service.ListCustomJobsRequest], - Awaitable[job_service.ListCustomJobsResponse]]: + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse], + ]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -320,18 +333,20 @@ def list_custom_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_custom_jobs' not in self._stubs: - self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListCustomJobs', + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListCustomJobs", request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs['list_custom_jobs'] + return self._stubs["list_custom_jobs"] @property - def delete_custom_job(self) -> Callable[ - [job_service.DeleteCustomJobRequest], - Awaitable[operations.Operation]]: + def delete_custom_job( + self, + ) -> Callable[ + [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -346,18 +361,18 @@ def delete_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_custom_job' not in self._stubs: - self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteCustomJob', + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteCustomJob", request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_custom_job'] + return self._stubs["delete_custom_job"] @property - def cancel_custom_job(self) -> Callable[ - [job_service.CancelCustomJobRequest], - Awaitable[empty.Empty]]: + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -384,18 +399,21 @@ def cancel_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_custom_job' not in self._stubs: - self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelCustomJob', + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelCustomJob", request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_custom_job'] + return self._stubs["cancel_custom_job"] @property - def create_data_labeling_job(self) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - Awaitable[gca_data_labeling_job.DataLabelingJob]]: + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob], + ]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -410,18 +428,21 @@ def create_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_data_labeling_job' not in self._stubs: - self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob', + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob", request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['create_data_labeling_job'] + return self._stubs["create_data_labeling_job"] @property - def get_data_labeling_job(self) -> Callable[ - [job_service.GetDataLabelingJobRequest], - Awaitable[data_labeling_job.DataLabelingJob]]: + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob], + ]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -436,18 +457,21 @@ def get_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_data_labeling_job' not in self._stubs: - self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob', + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob", request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['get_data_labeling_job'] + return self._stubs["get_data_labeling_job"] @property - def list_data_labeling_jobs(self) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - Awaitable[job_service.ListDataLabelingJobsResponse]]: + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse], + ]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -462,18 +486,20 @@ def list_data_labeling_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_labeling_jobs' not in self._stubs: - self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs', + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs", request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs['list_data_labeling_jobs'] + return self._stubs["list_data_labeling_jobs"] @property - def delete_data_labeling_job(self) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], - Awaitable[operations.Operation]]: + def delete_data_labeling_job( + self, + ) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -488,18 +514,18 @@ def delete_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_data_labeling_job' not in self._stubs: - self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob', + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob", request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_data_labeling_job'] + return self._stubs["delete_data_labeling_job"] @property - def cancel_data_labeling_job(self) -> Callable[ - [job_service.CancelDataLabelingJobRequest], - Awaitable[empty.Empty]]: + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -515,18 +541,21 @@ def cancel_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_data_labeling_job' not in self._stubs: - self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob', + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob", request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_data_labeling_job'] + return self._stubs["cancel_data_labeling_job"] @property - def create_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob]]: + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -542,18 +571,23 @@ def create_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_hyperparameter_tuning_job' not in self._stubs: - self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob', + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob", request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['create_hyperparameter_tuning_job'] + return self._stubs["create_hyperparameter_tuning_job"] @property - def get_hyperparameter_tuning_job(self) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob]]: + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -568,18 +602,23 @@ def get_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_hyperparameter_tuning_job' not in self._stubs: - self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob', + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob", request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['get_hyperparameter_tuning_job'] + return self._stubs["get_hyperparameter_tuning_job"] @property - def list_hyperparameter_tuning_jobs(self) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - Awaitable[job_service.ListHyperparameterTuningJobsResponse]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -595,18 +634,23 @@ def list_hyperparameter_tuning_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_hyperparameter_tuning_jobs' not in self._stubs: - self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs', + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs", request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs['list_hyperparameter_tuning_jobs'] + return self._stubs["list_hyperparameter_tuning_jobs"] @property - def delete_hyperparameter_tuning_job(self) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - Awaitable[operations.Operation]]: + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -622,18 +666,22 @@ def delete_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_hyperparameter_tuning_job' not in self._stubs: - self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob', + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob", request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_hyperparameter_tuning_job'] + return self._stubs["delete_hyperparameter_tuning_job"] @property - def cancel_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - Awaitable[empty.Empty]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -662,18 +710,23 @@ def cancel_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_hyperparameter_tuning_job' not in self._stubs: - self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob', + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob", request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_hyperparameter_tuning_job'] + return self._stubs["cancel_hyperparameter_tuning_job"] @property - def create_batch_prediction_job(self) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - Awaitable[gca_batch_prediction_job.BatchPredictionJob]]: + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -689,18 +742,21 @@ def create_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_batch_prediction_job' not in self._stubs: - self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob', + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob", request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['create_batch_prediction_job'] + return self._stubs["create_batch_prediction_job"] @property - def get_batch_prediction_job(self) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - Awaitable[batch_prediction_job.BatchPredictionJob]]: + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob], + ]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -715,18 +771,21 @@ def get_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_batch_prediction_job' not in self._stubs: - self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob', + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob", request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['get_batch_prediction_job'] + return self._stubs["get_batch_prediction_job"] @property - def list_batch_prediction_jobs(self) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - Awaitable[job_service.ListBatchPredictionJobsResponse]]: + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse], + ]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -741,18 +800,20 @@ def list_batch_prediction_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_batch_prediction_jobs' not in self._stubs: - self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs', + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs", request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs['list_batch_prediction_jobs'] + return self._stubs["list_batch_prediction_jobs"] @property - def delete_batch_prediction_job(self) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], - Awaitable[operations.Operation]]: + def delete_batch_prediction_job( + self, + ) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -768,18 +829,20 @@ def delete_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_batch_prediction_job' not in self._stubs: - self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob', + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob", request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_batch_prediction_job'] + return self._stubs["delete_batch_prediction_job"] @property - def cancel_batch_prediction_job(self) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], - Awaitable[empty.Empty]]: + def cancel_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -805,15 +868,13 @@ def cancel_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_batch_prediction_job' not in self._stubs: - self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob', + if "cancel_batch_prediction_job" not in self._stubs: + self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob", request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_batch_prediction_job'] + return self._stubs["cancel_batch_prediction_job"] -__all__ = ( - 'JobServiceGrpcAsyncIOTransport', -) +__all__ = ("JobServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1/services/migration_service/__init__.py index c533a12b45..1d6216d1f7 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/migration_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import MigrationServiceAsyncClient __all__ = ( - 'MigrationServiceClient', - 'MigrationServiceAsyncClient', + "MigrationServiceClient", + "MigrationServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py index d48eb4ae0b..e7f45eeaf5 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -51,7 +51,9 @@ class MigrationServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path) - parse_annotated_dataset_path = staticmethod(MigrationServiceClient.parse_annotated_dataset_path) + parse_annotated_dataset_path = staticmethod( + MigrationServiceClient.parse_annotated_dataset_path + ) dataset_path = staticmethod(MigrationServiceClient.dataset_path) parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) @@ -65,20 +67,34 @@ class MigrationServiceAsyncClient: version_path = staticmethod(MigrationServiceClient.version_path) parse_version_path = staticmethod(MigrationServiceClient.parse_version_path) - common_billing_account_path = staticmethod(MigrationServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(MigrationServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + MigrationServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + MigrationServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(MigrationServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(MigrationServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + MigrationServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(MigrationServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(MigrationServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + MigrationServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + MigrationServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(MigrationServiceClient.common_project_path) - parse_common_project_path = staticmethod(MigrationServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + MigrationServiceClient.parse_common_project_path + ) common_location_path = staticmethod(MigrationServiceClient.common_location_path) - parse_common_location_path = staticmethod(MigrationServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + MigrationServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -121,14 +137,18 @@ def transport(self) -> MigrationServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)) + get_transport_class = functools.partial( + type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, MigrationServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, MigrationServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -167,17 +187,17 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesAsyncPager: + async def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -218,8 +238,10 @@ async def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = migration_service.SearchMigratableResourcesRequest(request) @@ -240,40 +262,33 @@ async def search_migratable_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchMigratableResourcesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -322,8 +337,10 @@ async def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = migration_service.BatchMigrateResourcesRequest(request) @@ -347,18 +364,11 @@ async def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -372,21 +382,14 @@ async def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceAsyncClient', -) +__all__ = ("MigrationServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 94758701d8..042e3402d1 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,13 +50,14 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry['grpc'] = MigrationServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MigrationServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry["grpc"] = MigrationServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -110,7 +111,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -145,9 +146,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MigrationServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -162,143 +162,183 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: + def annotated_dataset_path( + project: str, dataset: str, annotated_dataset: str, + ) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str,str]: + def parse_annotated_dataset_path(path: str) -> Dict[str, str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def version_path(project: str,model: str,version: str,) -> str: + def version_path(project: str, model: str, version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + return "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) @staticmethod - def parse_version_path(path: str) -> Dict[str,str]: + def parse_version_path(path: str) -> Dict[str, str]: """Parse a version path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -342,7 +382,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -352,7 +394,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -364,7 +408,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -376,8 +422,10 @@ def __init__(self, *, if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -396,14 +444,15 @@ def __init__(self, *, client_info=client_info, ) - def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -444,8 +493,10 @@ def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -462,45 +513,40 @@ def search_migratable_resources(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -549,8 +595,10 @@ def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -574,18 +622,11 @@ def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation.from_gapic( @@ -599,21 +640,14 @@ def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceClient', -) +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1/services/migration_service/pagers.py index 08654cbf6e..02a46451df 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/migration_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import migratable_resource from google.cloud.aiplatform_v1.types import migration_service @@ -38,12 +47,15 @@ class SearchMigratableResourcesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., migration_service.SearchMigratableResourcesResponse], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: yield from page.migratable_resources def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchMigratableResourcesAsyncPager: @@ -97,12 +109,17 @@ class SearchMigratableResourcesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[migration_service.SearchMigratableResourcesResponse]], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[migration_service.SearchMigratableResourcesResponse] + ], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + async def pages( + self, + ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py index 9fb765fdcc..38c72756f6 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] -_transport_registry['grpc'] = MigrationServiceGrpcTransport -_transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = MigrationServiceGrpcTransport +_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport __all__ = ( - 'MigrationServiceTransport', - 'MigrationServiceGrpcTransport', - 'MigrationServiceGrpcAsyncIOTransport', + "MigrationServiceTransport", + "MigrationServiceGrpcTransport", + "MigrationServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py index 4f31e9b243..f10e4627c6 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -33,29 +33,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class MigrationServiceTransport(abc.ABC): """Abstract transport class for MigrationService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -78,8 +78,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -88,17 +88,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -116,7 +118,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -125,24 +126,25 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def search_migratable_resources(self) -> typing.Callable[ - [migration_service.SearchMigratableResourcesRequest], - typing.Union[ - migration_service.SearchMigratableResourcesResponse, - typing.Awaitable[migration_service.SearchMigratableResourcesResponse] - ]]: + def search_migratable_resources( + self, + ) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse], + ], + ]: raise NotImplementedError() @property - def batch_migrate_resources(self) -> typing.Callable[ - [migration_service.BatchMigrateResourcesRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def batch_migrate_resources( + self, + ) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'MigrationServiceTransport', -) +__all__ = ("MigrationServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py index 49659f9b31..b8cdb273a1 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -47,21 +47,24 @@ class MigrationServiceGrpcTransport(MigrationServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -173,13 +176,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -212,7 +217,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -230,17 +235,18 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def search_migratable_resources(self) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - migration_service.SearchMigratableResourcesResponse]: + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse, + ]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -258,18 +264,20 @@ def search_migratable_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_migratable_resources' not in self._stubs: - self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources', + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources", request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs['search_migratable_resources'] + return self._stubs["search_migratable_resources"] @property - def batch_migrate_resources(self) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - operations.Operation]: + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], operations.Operation + ]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -286,15 +294,13 @@ def batch_migrate_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_migrate_resources' not in self._stubs: - self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources', + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources", request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_migrate_resources'] + return self._stubs["batch_migrate_resources"] -__all__ = ( - 'MigrationServiceGrpcTransport', -) +__all__ = ("MigrationServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py index 600f8893fe..190f45eac1 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import migration_service @@ -54,13 +54,15 @@ class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -89,22 +91,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -243,9 +247,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def search_migratable_resources(self) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - Awaitable[migration_service.SearchMigratableResourcesResponse]]: + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse], + ]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -263,18 +270,21 @@ def search_migratable_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_migratable_resources' not in self._stubs: - self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources', + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources", request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs['search_migratable_resources'] + return self._stubs["search_migratable_resources"] @property - def batch_migrate_resources(self) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - Awaitable[operations.Operation]]: + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -291,15 +301,13 @@ def batch_migrate_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_migrate_resources' not in self._stubs: - self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources', + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources", request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_migrate_resources'] + return self._stubs["batch_migrate_resources"] -__all__ = ( - 'MigrationServiceGrpcAsyncIOTransport', -) +__all__ = ("MigrationServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/model_service/__init__.py b/google/cloud/aiplatform_v1/services/model_service/__init__.py index 3ee8fc6e9e..b39295ebfe 100644 --- a/google/cloud/aiplatform_v1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/model_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import ModelServiceAsyncClient __all__ = ( - 'ModelServiceClient', - 'ModelServiceAsyncClient', + "ModelServiceClient", + "ModelServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index a65c5df60f..687c22455a 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -62,26 +62,44 @@ class ModelServiceAsyncClient: model_path = staticmethod(ModelServiceClient.model_path) parse_model_path = staticmethod(ModelServiceClient.parse_model_path) model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) - parse_model_evaluation_path = staticmethod(ModelServiceClient.parse_model_evaluation_path) - model_evaluation_slice_path = staticmethod(ModelServiceClient.model_evaluation_slice_path) - parse_model_evaluation_slice_path = staticmethod(ModelServiceClient.parse_model_evaluation_slice_path) + parse_model_evaluation_path = staticmethod( + ModelServiceClient.parse_model_evaluation_path + ) + model_evaluation_slice_path = staticmethod( + ModelServiceClient.model_evaluation_slice_path + ) + parse_model_evaluation_slice_path = staticmethod( + ModelServiceClient.parse_model_evaluation_slice_path + ) training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod(ModelServiceClient.parse_training_pipeline_path) + parse_training_pipeline_path = staticmethod( + ModelServiceClient.parse_training_pipeline_path + ) - common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(ModelServiceClient.common_folder_path) parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + ModelServiceClient.parse_common_project_path + ) common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + ModelServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -124,14 +142,18 @@ def transport(self) -> ModelServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) + get_transport_class = functools.partial( + type(ModelServiceClient).get_transport_class, type(ModelServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -170,18 +192,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Uploads a Model artifact into AI Platform. Args: @@ -224,8 +246,10 @@ async def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.UploadModelRequest(request) @@ -248,18 +272,11 @@ async def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -272,14 +289,15 @@ async def upload_model(self, # Done; return the response. return response - async def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + async def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -309,8 +327,10 @@ async def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelRequest(request) @@ -331,30 +351,24 @@ async def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: + async def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: r"""Lists Models in a Location. Args: @@ -390,8 +404,10 @@ async def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelsRequest(request) @@ -412,40 +428,31 @@ async def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + async def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -483,8 +490,10 @@ async def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.UpdateModelRequest(request) @@ -507,30 +516,26 @@ async def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -578,8 +583,10 @@ async def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.DeleteModelRequest(request) @@ -600,18 +607,11 @@ async def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -624,15 +624,16 @@ async def delete_model(self, # Done; return the response. return response - async def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -680,8 +681,10 @@ async def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ExportModelRequest(request) @@ -704,18 +707,11 @@ async def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -728,14 +724,15 @@ async def export_model(self, # Done; return the response. return response - async def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + async def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -771,8 +768,10 @@ async def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelEvaluationRequest(request) @@ -793,30 +792,24 @@ async def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsAsyncPager: + async def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: r"""Lists ModelEvaluations in a Model. Args: @@ -852,8 +845,10 @@ async def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelEvaluationsRequest(request) @@ -874,39 +869,30 @@ async def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + async def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -942,8 +928,10 @@ async def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelEvaluationSliceRequest(request) @@ -964,30 +952,24 @@ async def get_model_evaluation_slice(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesAsyncPager: + async def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1024,8 +1006,10 @@ async def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelEvaluationSlicesRequest(request) @@ -1046,47 +1030,30 @@ async def list_model_evaluation_slices(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationSlicesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceAsyncClient', -) +__all__ = ("ModelServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index 9d5ebc8008..fa75f3c22b 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,12 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry['grpc'] = ModelServiceGrpcTransport - _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[ModelServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +116,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +151,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,121 +167,162 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_path(path: str) -> Dict[str, str]: """Parse a model_evaluation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -327,7 +366,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -337,7 +378,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -349,7 +392,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -361,8 +406,10 @@ def __init__(self, *, if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -381,15 +428,16 @@ def __init__(self, *, client_info=client_info, ) - def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -432,8 +480,10 @@ def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -457,18 +507,11 @@ def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -481,14 +524,15 @@ def upload_model(self, # Done; return the response. return response - def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -518,8 +562,10 @@ def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -541,30 +587,24 @@ def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -600,8 +640,10 @@ def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -623,40 +665,31 @@ def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -694,8 +727,10 @@ def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -719,30 +754,26 @@ def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -790,8 +821,10 @@ def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -813,18 +846,11 @@ def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -837,15 +863,16 @@ def delete_model(self, # Done; return the response. return response - def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -893,8 +920,10 @@ def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -918,18 +947,11 @@ def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -942,14 +964,15 @@ def export_model(self, # Done; return the response. return response - def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -985,8 +1008,10 @@ def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -1008,30 +1033,24 @@ def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1067,8 +1086,10 @@ def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1090,39 +1111,30 @@ def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1158,8 +1170,10 @@ def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1176,35 +1190,31 @@ def get_model_evaluation_slice(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1241,8 +1251,10 @@ def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1259,52 +1271,37 @@ def list_model_evaluation_slices(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceClient', -) +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/model_service/pagers.py b/google/cloud/aiplatform_v1/services/model_service/pagers.py index cf94a17fea..d01f0057c1 100644 --- a/google/cloud/aiplatform_v1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/model_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import model from google.cloud.aiplatform_v1.types import model_evaluation @@ -40,12 +49,15 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -79,7 +91,7 @@ def __iter__(self) -> Iterable[model.Model]: yield from page.models def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelsAsyncPager: @@ -99,12 +111,15 @@ class ListModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -142,7 +157,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationsPager: @@ -162,12 +177,15 @@ class ListModelEvaluationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelEvaluationsResponse], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +219,7 @@ def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: yield from page.model_evaluations def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationsAsyncPager: @@ -221,12 +239,15 @@ class ListModelEvaluationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -264,7 +285,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesPager: @@ -284,12 +305,15 @@ class ListModelEvaluationSlicesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelEvaluationSlicesResponse], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +347,7 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: yield from page.model_evaluation_slices def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesAsyncPager: @@ -343,12 +367,17 @@ class ListModelEvaluationSlicesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationSlicesResponse]], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] + ], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -370,7 +399,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + async def pages( + self, + ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -386,4 +417,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py index 833862a1d6..5d1cb51abc 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry['grpc'] = ModelServiceGrpcTransport -_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport __all__ = ( - 'ModelServiceTransport', - 'ModelServiceGrpcTransport', - 'ModelServiceGrpcAsyncIOTransport', + "ModelServiceTransport", + "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py index 80c34f3e4a..29a59f30da 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -37,29 +37,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -82,8 +82,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -92,17 +92,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -111,34 +113,22 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, - default_timeout=None, - client_info=client_info, + self.upload_model, default_timeout=None, client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( - self.get_model, - default_timeout=None, - client_info=client_info, + self.get_model, default_timeout=None, client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( - self.list_models, - default_timeout=None, - client_info=client_info, + self.list_models, default_timeout=None, client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( - self.update_model, - default_timeout=None, - client_info=client_info, + self.update_model, default_timeout=None, client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, - default_timeout=None, - client_info=client_info, + self.delete_model, default_timeout=None, client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( - self.export_model, - default_timeout=None, - client_info=client_info, + self.export_model, default_timeout=None, client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( self.get_model_evaluation, @@ -160,7 +150,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -169,96 +158,109 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def upload_model(self) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def upload_model( + self, + ) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model(self) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[ - model.Model, - typing.Awaitable[model.Model] - ]]: + def get_model( + self, + ) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[model.Model, typing.Awaitable[model.Model]], + ]: raise NotImplementedError() @property - def list_models(self) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse] - ]]: + def list_models( + self, + ) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse], + ], + ]: raise NotImplementedError() @property - def update_model(self) -> typing.Callable[ - [model_service.UpdateModelRequest], - typing.Union[ - gca_model.Model, - typing.Awaitable[gca_model.Model] - ]]: + def update_model( + self, + ) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], + ]: raise NotImplementedError() @property - def delete_model(self) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_model( + self, + ) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_model(self) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_model( + self, + ) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model_evaluation(self) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation] - ]]: + def get_model_evaluation( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation], + ], + ]: raise NotImplementedError() @property - def list_model_evaluations(self) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse] - ]]: + def list_model_evaluations( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse], + ], + ]: raise NotImplementedError() @property - def get_model_evaluation_slice(self) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] - ]]: + def get_model_evaluation_slice( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ], + ]: raise NotImplementedError() @property - def list_model_evaluation_slices(self) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] - ]]: + def list_model_evaluation_slices( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'ModelServiceTransport', -) +__all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py index d05154e2fb..92015d0848 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -49,21 +49,24 @@ class ModelServiceGrpcTransport(ModelServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -175,13 +178,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -214,7 +219,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -232,17 +237,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def upload_model(self) -> Callable[ - [model_service.UploadModelRequest], - operations.Operation]: + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], operations.Operation]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -257,18 +260,16 @@ def upload_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'upload_model' not in self._stubs: - self._stubs['upload_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/UploadModel', + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UploadModel", request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['upload_model'] + return self._stubs["upload_model"] @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -283,18 +284,18 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -309,18 +310,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def update_model(self) -> Callable[ - [model_service.UpdateModelRequest], - gca_model.Model]: + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -335,18 +336,18 @@ def update_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model' not in self._stubs: - self._stubs['update_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/UpdateModel', + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UpdateModel", request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs['update_model'] + return self._stubs["update_model"] @property - def delete_model(self) -> Callable[ - [model_service.DeleteModelRequest], - operations.Operation]: + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -363,18 +364,18 @@ def delete_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model' not in self._stubs: - self._stubs['delete_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/DeleteModel', + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/DeleteModel", request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model'] + return self._stubs["delete_model"] @property - def export_model(self) -> Callable[ - [model_service.ExportModelRequest], - operations.Operation]: + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -392,18 +393,20 @@ def export_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_model' not in self._stubs: - self._stubs['export_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ExportModel', + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ExportModel", request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_model'] + return self._stubs["export_model"] @property - def get_model_evaluation(self) -> Callable[ - [model_service.GetModelEvaluationRequest], - model_evaluation.ModelEvaluation]: + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation + ]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -418,18 +421,21 @@ def get_model_evaluation(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation' not in self._stubs: - self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation', + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation", request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs['get_model_evaluation'] + return self._stubs["get_model_evaluation"] @property - def list_model_evaluations(self) -> Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse]: + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse, + ]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -444,18 +450,21 @@ def list_model_evaluations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluations' not in self._stubs: - self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations', + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations", request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs['list_model_evaluations'] + return self._stubs["list_model_evaluations"] @property - def get_model_evaluation_slice(self) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice]: + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice, + ]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -470,18 +479,21 @@ def get_model_evaluation_slice(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation_slice' not in self._stubs: - self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice', + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice", request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs['get_model_evaluation_slice'] + return self._stubs["get_model_evaluation_slice"] @property - def list_model_evaluation_slices(self) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse]: + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse, + ]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -496,15 +508,13 @@ def list_model_evaluation_slices(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluation_slices' not in self._stubs: - self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices', + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices", request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs['list_model_evaluation_slices'] + return self._stubs["list_model_evaluation_slices"] -__all__ = ( - 'ModelServiceGrpcTransport', -) +__all__ = ("ModelServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py index 1e24fe3d5c..2de86d2623 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import model @@ -56,13 +56,15 @@ class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -91,22 +93,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -245,9 +249,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def upload_model(self) -> Callable[ - [model_service.UploadModelRequest], - Awaitable[operations.Operation]]: + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -262,18 +266,18 @@ def upload_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'upload_model' not in self._stubs: - self._stubs['upload_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/UploadModel', + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UploadModel", request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['upload_model'] + return self._stubs["upload_model"] @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Awaitable[model.Model]]: + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -288,18 +292,20 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Awaitable[model_service.ListModelsResponse]]: + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -314,18 +320,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def update_model(self) -> Callable[ - [model_service.UpdateModelRequest], - Awaitable[gca_model.Model]]: + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -340,18 +346,18 @@ def update_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model' not in self._stubs: - self._stubs['update_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/UpdateModel', + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UpdateModel", request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs['update_model'] + return self._stubs["update_model"] @property - def delete_model(self) -> Callable[ - [model_service.DeleteModelRequest], - Awaitable[operations.Operation]]: + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -368,18 +374,18 @@ def delete_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model' not in self._stubs: - self._stubs['delete_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/DeleteModel', + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/DeleteModel", request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model'] + return self._stubs["delete_model"] @property - def export_model(self) -> Callable[ - [model_service.ExportModelRequest], - Awaitable[operations.Operation]]: + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -397,18 +403,21 @@ def export_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_model' not in self._stubs: - self._stubs['export_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ExportModel', + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ExportModel", request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_model'] + return self._stubs["export_model"] @property - def get_model_evaluation(self) -> Callable[ - [model_service.GetModelEvaluationRequest], - Awaitable[model_evaluation.ModelEvaluation]]: + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation], + ]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -423,18 +432,21 @@ def get_model_evaluation(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation' not in self._stubs: - self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation', + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation", request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs['get_model_evaluation'] + return self._stubs["get_model_evaluation"] @property - def list_model_evaluations(self) -> Callable[ - [model_service.ListModelEvaluationsRequest], - Awaitable[model_service.ListModelEvaluationsResponse]]: + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse], + ]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -449,18 +461,21 @@ def list_model_evaluations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluations' not in self._stubs: - self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations', + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations", request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs['list_model_evaluations'] + return self._stubs["list_model_evaluations"] @property - def get_model_evaluation_slice(self) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - Awaitable[model_evaluation_slice.ModelEvaluationSlice]]: + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -475,18 +490,21 @@ def get_model_evaluation_slice(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation_slice' not in self._stubs: - self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice', + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice", request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs['get_model_evaluation_slice'] + return self._stubs["get_model_evaluation_slice"] @property - def list_model_evaluation_slices(self) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - Awaitable[model_service.ListModelEvaluationSlicesResponse]]: + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse], + ]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -501,15 +519,13 @@ def list_model_evaluation_slices(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluation_slices' not in self._stubs: - self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices', + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices", request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs['list_model_evaluation_slices'] + return self._stubs["list_model_evaluation_slices"] -__all__ = ( - 'ModelServiceGrpcAsyncIOTransport', -) +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py index f7f4d9b9ac..7f02b47358 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PipelineServiceAsyncClient __all__ = ( - 'PipelineServiceClient', - 'PipelineServiceAsyncClient', + "PipelineServiceClient", + "PipelineServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index 276c0980f5..fc7337a7a3 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,22 +61,38 @@ class PipelineServiceAsyncClient: model_path = staticmethod(PipelineServiceClient.model_path) parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod(PipelineServiceClient.parse_training_pipeline_path) + parse_training_pipeline_path = staticmethod( + PipelineServiceClient.parse_training_pipeline_path + ) - common_billing_account_path = staticmethod(PipelineServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PipelineServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + PipelineServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PipelineServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PipelineServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + PipelineServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(PipelineServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PipelineServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + PipelineServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PipelineServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PipelineServiceClient.common_project_path) - parse_common_project_path = staticmethod(PipelineServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PipelineServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PipelineServiceClient.common_location_path) - parse_common_location_path = staticmethod(PipelineServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PipelineServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -119,14 +135,18 @@ def transport(self) -> PipelineServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient)) + get_transport_class = functools.partial( + type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -165,18 +185,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + async def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -221,8 +241,10 @@ async def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.CreateTrainingPipelineRequest(request) @@ -245,30 +267,24 @@ async def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + async def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -306,8 +322,10 @@ async def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.GetTrainingPipelineRequest(request) @@ -328,30 +346,24 @@ async def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesAsyncPager: + async def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: r"""Lists TrainingPipelines in a Location. Args: @@ -387,8 +399,10 @@ async def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.ListTrainingPipelinesRequest(request) @@ -409,39 +423,30 @@ async def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrainingPipelinesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a TrainingPipeline. Args: @@ -488,8 +493,10 @@ async def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.DeleteTrainingPipelineRequest(request) @@ -510,18 +517,11 @@ async def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -534,14 +534,15 @@ async def delete_training_pipeline(self, # Done; return the response. return response - async def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -581,8 +582,10 @@ async def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.CancelTrainingPipelineRequest(request) @@ -603,35 +606,23 @@ async def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceAsyncClient', -) +__all__ = ("PipelineServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index fe36174dda..39f37eb72e 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -59,13 +59,14 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry['grpc'] = PipelineServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PipelineServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry["grpc"] = PipelineServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -151,9 +152,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PipelineServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,99 +168,122 @@ def transport(self) -> PipelineServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -304,7 +327,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -314,7 +339,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -326,7 +353,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -338,8 +367,10 @@ def __init__(self, *, if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -358,15 +389,16 @@ def __init__(self, *, client_info=client_info, ) - def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -411,8 +443,10 @@ def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -436,30 +470,24 @@ def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -497,8 +525,10 @@ def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -520,30 +550,24 @@ def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -579,8 +603,10 @@ def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -602,39 +628,30 @@ def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -681,8 +698,10 @@ def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -704,18 +723,11 @@ def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -728,14 +740,15 @@ def delete_training_pipeline(self, # Done; return the response. return response - def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -775,8 +788,10 @@ def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -798,35 +813,23 @@ def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceClient', -) +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py index ec626400ec..987c37dba2 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import pipeline_service from google.cloud.aiplatform_v1.types import training_pipeline @@ -38,12 +47,15 @@ class ListTrainingPipelinesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: yield from page.training_pipelines def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTrainingPipelinesAsyncPager: @@ -97,12 +109,17 @@ class ListTrainingPipelinesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[pipeline_service.ListTrainingPipelinesResponse]], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + async def pages( + self, + ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py index f289718f83..9d4610087a 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] -_transport_registry['grpc'] = PipelineServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = PipelineServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport __all__ = ( - 'PipelineServiceTransport', - 'PipelineServiceGrpcTransport', - 'PipelineServiceGrpcAsyncIOTransport', + "PipelineServiceTransport", + "PipelineServiceGrpcTransport", + "PipelineServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py index 3a0cfa5a08..901318016a 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,8 +81,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -91,17 +91,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -134,7 +136,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -143,51 +144,58 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_training_pipeline(self) -> typing.Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - typing.Union[ - gca_training_pipeline.TrainingPipeline, - typing.Awaitable[gca_training_pipeline.TrainingPipeline] - ]]: + def create_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline], + ], + ]: raise NotImplementedError() @property - def get_training_pipeline(self) -> typing.Callable[ - [pipeline_service.GetTrainingPipelineRequest], - typing.Union[ - training_pipeline.TrainingPipeline, - typing.Awaitable[training_pipeline.TrainingPipeline] - ]]: + def get_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline], + ], + ]: raise NotImplementedError() @property - def list_training_pipelines(self) -> typing.Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - typing.Union[ - pipeline_service.ListTrainingPipelinesResponse, - typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse] - ]]: + def list_training_pipelines( + self, + ) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ], + ]: raise NotImplementedError() @property - def delete_training_pipeline(self) -> typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_training_pipeline(self) -> typing.Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'PipelineServiceTransport', -) +__all__ = ("PipelineServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py index 4f19145175..abf870cde5 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -48,21 +48,24 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -174,13 +177,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -213,7 +218,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -231,17 +236,18 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_training_pipeline(self) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline]: + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + gca_training_pipeline.TrainingPipeline, + ]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -257,18 +263,21 @@ def create_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_training_pipeline' not in self._stubs: - self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline', + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline", request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['create_training_pipeline'] + return self._stubs["create_training_pipeline"] @property - def get_training_pipeline(self) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline]: + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline, + ]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -283,18 +292,21 @@ def get_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_training_pipeline' not in self._stubs: - self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline', + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline", request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['get_training_pipeline'] + return self._stubs["get_training_pipeline"] @property - def list_training_pipelines(self) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse]: + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse, + ]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -309,18 +321,20 @@ def list_training_pipelines(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_training_pipelines' not in self._stubs: - self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines', + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines", request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs['list_training_pipelines'] + return self._stubs["list_training_pipelines"] @property - def delete_training_pipeline(self) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - operations.Operation]: + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation + ]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -335,18 +349,18 @@ def delete_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_training_pipeline' not in self._stubs: - self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline', + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline", request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_training_pipeline'] + return self._stubs["delete_training_pipeline"] @property - def cancel_training_pipeline(self) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - empty.Empty]: + def cancel_training_pipeline( + self, + ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -373,15 +387,13 @@ def cancel_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_training_pipeline' not in self._stubs: - self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline', + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline", request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_training_pipeline'] + return self._stubs["cancel_training_pipeline"] -__all__ = ( - 'PipelineServiceGrpcTransport', -) +__all__ = ("PipelineServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py index 8a0f1f7534..01f93c5600 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import pipeline_service @@ -55,13 +55,15 @@ class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -90,22 +92,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -244,9 +248,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_training_pipeline(self) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - Awaitable[gca_training_pipeline.TrainingPipeline]]: + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline], + ]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -262,18 +269,21 @@ def create_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_training_pipeline' not in self._stubs: - self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline', + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline", request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['create_training_pipeline'] + return self._stubs["create_training_pipeline"] @property - def get_training_pipeline(self) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - Awaitable[training_pipeline.TrainingPipeline]]: + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline], + ]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -288,18 +298,21 @@ def get_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_training_pipeline' not in self._stubs: - self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline', + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline", request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['get_training_pipeline'] + return self._stubs["get_training_pipeline"] @property - def list_training_pipelines(self) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - Awaitable[pipeline_service.ListTrainingPipelinesResponse]]: + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -314,18 +327,21 @@ def list_training_pipelines(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_training_pipelines' not in self._stubs: - self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines', + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines", request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs['list_training_pipelines'] + return self._stubs["list_training_pipelines"] @property - def delete_training_pipeline(self) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - Awaitable[operations.Operation]]: + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -340,18 +356,20 @@ def delete_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_training_pipeline' not in self._stubs: - self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline', + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline", request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_training_pipeline'] + return self._stubs["delete_training_pipeline"] @property - def cancel_training_pipeline(self) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - Awaitable[empty.Empty]]: + def cancel_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -378,15 +396,13 @@ def cancel_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_training_pipeline' not in self._stubs: - self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline', + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline", request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_training_pipeline'] + return self._stubs["cancel_training_pipeline"] -__all__ = ( - 'PipelineServiceGrpcAsyncIOTransport', -) +__all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1/services/prediction_service/__init__.py index d4047c335d..0c847693e0 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PredictionServiceAsyncClient __all__ = ( - 'PredictionServiceClient', - 'PredictionServiceAsyncClient', + "PredictionServiceClient", + "PredictionServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index 299694bdce..cc6d011e88 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore @@ -47,20 +47,34 @@ class PredictionServiceAsyncClient: endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) - common_billing_account_path = staticmethod(PredictionServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PredictionServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + PredictionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PredictionServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PredictionServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + PredictionServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(PredictionServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PredictionServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + PredictionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PredictionServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PredictionServiceClient.common_project_path) - parse_common_project_path = staticmethod(PredictionServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PredictionServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PredictionServiceClient.common_location_path) - parse_common_location_path = staticmethod(PredictionServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PredictionServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -103,14 +117,18 @@ def transport(self) -> PredictionServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient)) + get_transport_class = functools.partial( + type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PredictionServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -149,19 +167,19 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def predict(self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + async def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -221,8 +239,10 @@ async def predict(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = prediction_service.PredictRequest(request) @@ -248,38 +268,24 @@ async def predict(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PredictionServiceAsyncClient', -) +__all__ = ("PredictionServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index 7d9294a251..029fb851b8 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore @@ -47,13 +47,16 @@ class PredictionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] - _transport_registry['grpc'] = PredictionServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PredictionServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry["grpc"] = PredictionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[PredictionServiceTransport]: """Return an appropriate transport class. Args: @@ -104,7 +107,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -139,9 +142,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PredictionServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -156,77 +158,88 @@ def transport(self) -> PredictionServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -270,7 +283,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -280,7 +295,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -292,7 +309,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -304,8 +323,10 @@ def __init__(self, *, if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -324,16 +345,17 @@ def __init__(self, *, client_info=client_info, ) - def predict(self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -393,8 +415,10 @@ def predict(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a prediction_service.PredictRequest. @@ -420,38 +444,24 @@ def predict(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PredictionServiceClient', -) +__all__ = ("PredictionServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py index 15b5acb198..9ec1369a05 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] -_transport_registry['grpc'] = PredictionServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = PredictionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport __all__ = ( - 'PredictionServiceTransport', - 'PredictionServiceGrpcTransport', - 'PredictionServiceGrpcAsyncIOTransport', + "PredictionServiceTransport", + "PredictionServiceGrpcTransport", + "PredictionServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index 9e8a9841c0..fb30d2533e 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -76,8 +76,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -86,17 +86,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -105,23 +107,21 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( - self.predict, - default_timeout=None, - client_info=client_info, + self.predict, default_timeout=None, client_info=client_info, ), - } @property - def predict(self) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse] - ]]: + def predict( + self, + ) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'PredictionServiceTransport', -) +__all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index 484a1193b1..f78e11bd2d 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -18,10 +18,10 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -43,21 +43,24 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -168,13 +171,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -207,7 +212,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -217,9 +222,11 @@ def grpc_channel(self) -> grpc.Channel: return self._grpc_channel @property - def predict(self) -> Callable[ - [prediction_service.PredictRequest], - prediction_service.PredictResponse]: + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], prediction_service.PredictResponse + ]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -234,15 +241,13 @@ def predict(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'predict' not in self._stubs: - self._stubs['predict'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PredictionService/Predict', + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PredictionService/Predict", request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs['predict'] + return self._stubs["predict"] -__all__ = ( - 'PredictionServiceGrpcTransport', -) +__all__ = ("PredictionServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py index 87a9970365..c9d5e2ba94 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py @@ -18,13 +18,13 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import prediction_service @@ -50,13 +50,15 @@ class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -85,22 +87,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -222,9 +226,12 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def predict(self) -> Callable[ - [prediction_service.PredictRequest], - Awaitable[prediction_service.PredictResponse]]: + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse], + ]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -239,15 +246,13 @@ def predict(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'predict' not in self._stubs: - self._stubs['predict'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.PredictionService/Predict', + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PredictionService/Predict", request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs['predict'] + return self._stubs["predict"] -__all__ = ( - 'PredictionServiceGrpcAsyncIOTransport', -) +__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py index e4247d7758..49e9cdf0a0 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import SpecialistPoolServiceAsyncClient __all__ = ( - 'SpecialistPoolServiceClient', - 'SpecialistPoolServiceAsyncClient', + "SpecialistPoolServiceClient", + "SpecialistPoolServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index be193ead83..57e2b8a0a7 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,23 +57,43 @@ class SpecialistPoolServiceAsyncClient: DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT - specialist_pool_path = staticmethod(SpecialistPoolServiceClient.specialist_pool_path) - parse_specialist_pool_path = staticmethod(SpecialistPoolServiceClient.parse_specialist_pool_path) + specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.specialist_pool_path + ) + parse_specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.parse_specialist_pool_path + ) - common_billing_account_path = staticmethod(SpecialistPoolServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(SpecialistPoolServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(SpecialistPoolServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + SpecialistPoolServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(SpecialistPoolServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(SpecialistPoolServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + SpecialistPoolServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + SpecialistPoolServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) - parse_common_project_path = staticmethod(SpecialistPoolServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + SpecialistPoolServiceClient.parse_common_project_path + ) - common_location_path = staticmethod(SpecialistPoolServiceClient.common_location_path) - parse_common_location_path = staticmethod(SpecialistPoolServiceClient.parse_common_location_path) + common_location_path = staticmethod( + SpecialistPoolServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + SpecialistPoolServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -116,14 +136,19 @@ def transport(self) -> SpecialistPoolServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(SpecialistPoolServiceClient).get_transport_class, type(SpecialistPoolServiceClient)) + get_transport_class = functools.partial( + type(SpecialistPoolServiceClient).get_transport_class, + type(SpecialistPoolServiceClient), + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -162,18 +187,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a SpecialistPool. Args: @@ -221,8 +246,10 @@ async def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.CreateSpecialistPoolRequest(request) @@ -245,18 +272,11 @@ async def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -269,14 +289,15 @@ async def create_specialist_pool(self, # Done; return the response. return response - async def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + async def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -319,8 +340,10 @@ async def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.GetSpecialistPoolRequest(request) @@ -341,30 +364,24 @@ async def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsAsyncPager: + async def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: r"""Lists SpecialistPools in a Location. Args: @@ -400,8 +417,10 @@ async def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.ListSpecialistPoolsRequest(request) @@ -422,39 +441,30 @@ async def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListSpecialistPoolsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -501,8 +511,10 @@ async def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.DeleteSpecialistPoolRequest(request) @@ -523,18 +535,11 @@ async def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -547,15 +552,16 @@ async def delete_specialist_pool(self, # Done; return the response. return response - async def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates a SpecialistPool. Args: @@ -602,8 +608,10 @@ async def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.UpdateSpecialistPoolRequest(request) @@ -626,18 +634,13 @@ async def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -651,21 +654,14 @@ async def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceAsyncClient', -) +__all__ = ("SpecialistPoolServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index efb32eaa6e..c6429b54f8 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,13 +54,16 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport - _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +120,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +155,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: SpecialistPoolServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,77 +171,88 @@ def transport(self) -> SpecialistPoolServiceTransport: return self._transport @staticmethod - def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str: + def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" - return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) @staticmethod - def parse_specialist_pool_path(path: str) -> Dict[str,str]: + def parse_specialist_pool_path(path: str) -> Dict[str, str]: """Parse a specialist_pool path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -283,7 +296,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -293,7 +308,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -305,7 +322,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -317,8 +336,10 @@ def __init__(self, *, if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -337,15 +358,16 @@ def __init__(self, *, client_info=client_info, ) - def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -393,8 +415,10 @@ def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -418,18 +442,11 @@ def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -442,14 +459,15 @@ def create_specialist_pool(self, # Done; return the response. return response - def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -492,8 +510,10 @@ def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -515,30 +535,24 @@ def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -574,8 +588,10 @@ def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -597,39 +613,30 @@ def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -676,8 +683,10 @@ def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -699,18 +708,11 @@ def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -723,15 +725,16 @@ def delete_specialist_pool(self, # Done; return the response. return response - def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -778,8 +781,10 @@ def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -803,18 +808,13 @@ def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -828,21 +828,14 @@ def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceClient', -) +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py index 87590e0e87..e64a827049 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import specialist_pool from google.cloud.aiplatform_v1.types import specialist_pool_service @@ -38,12 +47,15 @@ class ListSpecialistPoolsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: yield from page.specialist_pools def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListSpecialistPoolsAsyncPager: @@ -97,12 +109,17 @@ class ListSpecialistPoolsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + async def pages( + self, + ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py index 80de7b209f..1bb2fbf22a 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py @@ -24,12 +24,14 @@ # Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] -_transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport -_transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[SpecialistPoolServiceTransport]] +_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport +_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( - 'SpecialistPoolServiceTransport', - 'SpecialistPoolServiceGrpcTransport', - 'SpecialistPoolServiceGrpcAsyncIOTransport', + "SpecialistPoolServiceTransport", + "SpecialistPoolServiceGrpcTransport", + "SpecialistPoolServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py index 878e095edb..15338f63b9 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -79,8 +79,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -89,17 +89,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -113,9 +115,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, - default_timeout=None, - client_info=client_info, + self.get_specialist_pool, default_timeout=None, client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, @@ -132,7 +132,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -141,51 +140,55 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool] - ]]: + def get_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool], + ], + ]: raise NotImplementedError() @property - def list_specialist_pools(self) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ]]: + def list_specialist_pools( + self, + ) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ], + ]: raise NotImplementedError() @property - def delete_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def update_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'SpecialistPoolServiceTransport', -) +__all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py index 7574c12f22..97bb19e261 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,21 +51,24 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -177,13 +180,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -216,7 +221,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -234,17 +239,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_specialist_pool(self) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - operations.Operation]: + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -259,18 +264,21 @@ def create_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_specialist_pool' not in self._stubs: - self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool', + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool", request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_specialist_pool'] + return self._stubs["create_specialist_pool"] @property - def get_specialist_pool(self) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool]: + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool, + ]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -285,18 +293,21 @@ def get_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_specialist_pool' not in self._stubs: - self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool', + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool", request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs['get_specialist_pool'] + return self._stubs["get_specialist_pool"] @property - def list_specialist_pools(self) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse]: + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse, + ]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -311,18 +322,20 @@ def list_specialist_pools(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_specialist_pools' not in self._stubs: - self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools', + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools", request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs['list_specialist_pools'] + return self._stubs["list_specialist_pools"] @property - def delete_specialist_pool(self) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - operations.Operation]: + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -338,18 +351,20 @@ def delete_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_specialist_pool' not in self._stubs: - self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool', + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool", request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_specialist_pool'] + return self._stubs["delete_specialist_pool"] @property - def update_specialist_pool(self) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - operations.Operation]: + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -364,15 +379,13 @@ def update_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_specialist_pool' not in self._stubs: - self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool', + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool", request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_specialist_pool'] + return self._stubs["update_specialist_pool"] -__all__ = ( - 'SpecialistPoolServiceGrpcTransport', -) +__all__ = ("SpecialistPoolServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py index 2766d7848b..fd7766a767 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1.types import specialist_pool @@ -58,13 +58,15 @@ class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -93,22 +95,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -247,9 +251,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_specialist_pool(self) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -264,18 +271,21 @@ def create_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_specialist_pool' not in self._stubs: - self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool', + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool", request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_specialist_pool'] + return self._stubs["create_specialist_pool"] @property - def get_specialist_pool(self) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - Awaitable[specialist_pool.SpecialistPool]]: + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool], + ]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -290,18 +300,21 @@ def get_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_specialist_pool' not in self._stubs: - self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool', + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool", request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs['get_specialist_pool'] + return self._stubs["get_specialist_pool"] @property - def list_specialist_pools(self) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]]: + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -316,18 +329,21 @@ def list_specialist_pools(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_specialist_pools' not in self._stubs: - self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools', + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools", request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs['list_specialist_pools'] + return self._stubs["list_specialist_pools"] @property - def delete_specialist_pool(self) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -343,18 +359,21 @@ def delete_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_specialist_pool' not in self._stubs: - self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool', + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool", request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_specialist_pool'] + return self._stubs["delete_specialist_pool"] @property - def update_specialist_pool(self) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -369,15 +388,13 @@ def update_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_specialist_pool' not in self._stubs: - self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool', + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool", request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_specialist_pool'] + return self._stubs["update_specialist_pool"] -__all__ = ( - 'SpecialistPoolServiceGrpcAsyncIOTransport', -) +__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index b33ec9f9b8..6d7c9ca42f 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -15,18 +15,10 @@ # limitations under the License. # -from .annotation import ( - Annotation, -) -from .annotation_spec import ( - AnnotationSpec, -) -from .batch_prediction_job import ( - BatchPredictionJob, -) -from .completion_stats import ( - CompletionStats, -) +from .annotation import Annotation +from .annotation_spec import AnnotationSpec +from .batch_prediction_job import BatchPredictionJob +from .completion_stats import CompletionStats from .custom_job import ( ContainerSpec, CustomJob, @@ -35,9 +27,7 @@ Scheduling, WorkerPoolSpec, ) -from .data_item import ( - DataItem, -) +from .data_item import DataItem from .data_labeling_job import ( ActiveLearningConfig, DataLabelingJob, @@ -69,12 +59,8 @@ ListDatasetsResponse, UpdateDatasetRequest, ) -from .deployed_model_ref import ( - DeployedModelRef, -) -from .encryption_spec import ( - EncryptionSpec, -) +from .deployed_model_ref import DeployedModelRef +from .encryption_spec import EncryptionSpec from .endpoint import ( DeployedModel, Endpoint, @@ -94,12 +80,8 @@ UndeployModelResponse, UpdateEndpointRequest, ) -from .env_var import ( - EnvVar, -) -from .hyperparameter_tuning_job import ( - HyperparameterTuningJob, -) +from .env_var import EnvVar +from .hyperparameter_tuning_job import HyperparameterTuningJob from .io import ( BigQueryDestination, BigQuerySource, @@ -141,12 +123,8 @@ MachineSpec, ResourcesConsumed, ) -from .manual_batch_tuning_parameters import ( - ManualBatchTuningParameters, -) -from .migratable_resource import ( - MigratableResource, -) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters +from .migratable_resource import MigratableResource from .migration_service import ( BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, @@ -162,12 +140,8 @@ Port, PredictSchemata, ) -from .model_evaluation import ( - ModelEvaluation, -) -from .model_evaluation_slice import ( - ModelEvaluationSlice, -) +from .model_evaluation import ModelEvaluation +from .model_evaluation_slice import ModelEvaluationSlice from .model_service import ( DeleteModelRequest, ExportModelOperationMetadata, @@ -203,9 +177,7 @@ PredictRequest, PredictResponse, ) -from .specialist_pool import ( - SpecialistPool, -) +from .specialist_pool import SpecialistPool from .specialist_pool_service import ( CreateSpecialistPoolOperationMetadata, CreateSpecialistPoolRequest, @@ -229,163 +201,161 @@ TimestampSplit, TrainingPipeline, ) -from .user_action_reference import ( - UserActionReference, -) +from .user_action_reference import UserActionReference __all__ = ( - 'AcceleratorType', - 'Annotation', - 'AnnotationSpec', - 'BatchPredictionJob', - 'CompletionStats', - 'ContainerSpec', - 'CustomJob', - 'CustomJobSpec', - 'PythonPackageSpec', - 'Scheduling', - 'WorkerPoolSpec', - 'DataItem', - 'ActiveLearningConfig', - 'DataLabelingJob', - 'SampleConfig', - 'TrainingConfig', - 'Dataset', - 'ExportDataConfig', - 'ImportDataConfig', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'DeleteDatasetRequest', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'GetAnnotationSpecRequest', - 'GetDatasetRequest', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'UpdateDatasetRequest', - 'DeployedModelRef', - 'EncryptionSpec', - 'DeployedModel', - 'Endpoint', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateEndpointRequest', - 'EnvVar', - 'HyperparameterTuningJob', - 'BigQueryDestination', - 'BigQuerySource', - 'ContainerRegistryDestination', - 'GcsDestination', - 'GcsSource', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'DeleteBatchPredictionJobRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteHyperparameterTuningJobRequest', - 'GetBatchPredictionJobRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetHyperparameterTuningJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'JobState', - 'AutomaticResources', - 'BatchDedicatedResources', - 'DedicatedResources', - 'DiskSpec', - 'MachineSpec', - 'ResourcesConsumed', - 'ManualBatchTuningParameters', - 'MigratableResource', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'Model', - 'ModelContainerSpec', - 'Port', - 'PredictSchemata', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'DeleteModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 'UploadModelResponse', - 'DeleteOperationMetadata', - 'GenericOperationMetadata', - 'CancelTrainingPipelineRequest', - 'CreateTrainingPipelineRequest', - 'DeleteTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'PipelineState', - 'PredictRequest', - 'PredictResponse', - 'SpecialistPool', - 'CreateSpecialistPoolOperationMetadata', - 'CreateSpecialistPoolRequest', - 'DeleteSpecialistPoolRequest', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'Measurement', - 'StudySpec', - 'Trial', - 'FilterSplit', - 'FractionSplit', - 'InputDataConfig', - 'PredefinedSplit', - 'TimestampSplit', - 'TrainingPipeline', - 'UserActionReference', + "AcceleratorType", + "Annotation", + "AnnotationSpec", + "BatchPredictionJob", + "CompletionStats", + "ContainerSpec", + "CustomJob", + "CustomJobSpec", + "PythonPackageSpec", + "Scheduling", + "WorkerPoolSpec", + "DataItem", + "ActiveLearningConfig", + "DataLabelingJob", + "SampleConfig", + "TrainingConfig", + "Dataset", + "ExportDataConfig", + "ImportDataConfig", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "DeleteDatasetRequest", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "GetAnnotationSpecRequest", + "GetDatasetRequest", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "UpdateDatasetRequest", + "DeployedModelRef", + "EncryptionSpec", + "DeployedModel", + "Endpoint", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateEndpointRequest", + "EnvVar", + "HyperparameterTuningJob", + "BigQueryDestination", + "BigQuerySource", + "ContainerRegistryDestination", + "GcsDestination", + "GcsSource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteHyperparameterTuningJobRequest", + "GetBatchPredictionJobRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetHyperparameterTuningJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "JobState", + "AutomaticResources", + "BatchDedicatedResources", + "DedicatedResources", + "DiskSpec", + "MachineSpec", + "ResourcesConsumed", + "ManualBatchTuningParameters", + "MigratableResource", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceRequest", + "MigrateResourceResponse", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "Model", + "ModelContainerSpec", + "Port", + "PredictSchemata", + "ModelEvaluation", + "ModelEvaluationSlice", + "DeleteModelRequest", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "DeleteOperationMetadata", + "GenericOperationMetadata", + "CancelTrainingPipelineRequest", + "CreateTrainingPipelineRequest", + "DeleteTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "PipelineState", + "PredictRequest", + "PredictResponse", + "SpecialistPool", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "DeleteSpecialistPoolRequest", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "Measurement", + "StudySpec", + "Trial", + "FilterSplit", + "FractionSplit", + "InputDataConfig", + "PredefinedSplit", + "TimestampSplit", + "TrainingPipeline", + "UserActionReference", ) diff --git a/google/cloud/aiplatform_v1/types/accelerator_type.py b/google/cloud/aiplatform_v1/types/accelerator_type.py index b22abd8ffb..640436c38c 100644 --- a/google/cloud/aiplatform_v1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1/types/accelerator_type.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'AcceleratorType', - }, + package="google.cloud.aiplatform.v1", manifest={"AcceleratorType",}, ) diff --git a/google/cloud/aiplatform_v1/types/annotation.py b/google/cloud/aiplatform_v1/types/annotation.py index 3a08c3dead..000ca49dcb 100644 --- a/google/cloud/aiplatform_v1/types/annotation.py +++ b/google/cloud/aiplatform_v1/types/annotation.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'Annotation', - }, + package="google.cloud.aiplatform.v1", manifest={"Annotation",}, ) @@ -94,22 +91,16 @@ class Annotation(proto.Message): payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field(proto.MESSAGE, number=5, - message=user_action_reference.UserActionReference, + annotation_source = proto.Field( + proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, ) labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1/types/annotation_spec.py b/google/cloud/aiplatform_v1/types/annotation_spec.py index 4bcd10d1ba..41f228ad72 100644 --- a/google/cloud/aiplatform_v1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1/types/annotation_spec.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'AnnotationSpec', - }, + package="google.cloud.aiplatform.v1", manifest={"AnnotationSpec",}, ) @@ -58,13 +55,9 @@ class AnnotationSpec(proto.Message): display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1/types/batch_prediction_job.py index a75a861570..d2d8f02203 100644 --- a/google/cloud/aiplatform_v1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1/types/batch_prediction_job.py @@ -23,17 +23,16 @@ from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources -from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters +from google.cloud.aiplatform_v1.types import ( + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, +) from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'BatchPredictionJob', - }, + package="google.cloud.aiplatform.v1", manifest={"BatchPredictionJob",}, ) @@ -148,6 +147,7 @@ class BatchPredictionJob(proto.Message): resources created by the BatchPredictionJob will be encrypted with the provided encryption key. """ + class InputConfig(proto.Message): r"""Configures the input to ``BatchPredictionJob``. @@ -174,12 +174,12 @@ class InputConfig(proto.Message): ``supported_input_storage_formats``. """ - gcs_source = proto.Field(proto.MESSAGE, number=2, oneof='source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, ) - bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', - message=io.BigQuerySource, + bigquery_source = proto.Field( + proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, ) instances_format = proto.Field(proto.STRING, number=1) @@ -250,11 +250,14 @@ class OutputConfig(proto.Message): ``supported_output_storage_formats``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, ) - bigquery_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', + bigquery_destination = proto.Field( + proto.MESSAGE, + number=3, + oneof="destination", message=io.BigQueryDestination, ) @@ -275,9 +278,13 @@ class OutputInfo(proto.Message): prediction output is written. """ - gcs_output_directory = proto.Field(proto.STRING, number=1, oneof='output_location') + gcs_output_directory = proto.Field( + proto.STRING, number=1, oneof="output_location" + ) - bigquery_output_dataset = proto.Field(proto.STRING, number=2, oneof='output_location') + bigquery_output_dataset = proto.Field( + proto.STRING, number=2, oneof="output_location" + ) name = proto.Field(proto.STRING, number=1) @@ -285,70 +292,52 @@ class OutputInfo(proto.Message): model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, - message=InputConfig, - ) + input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) - model_parameters = proto.Field(proto.MESSAGE, number=5, - message=struct.Value, - ) + model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - output_config = proto.Field(proto.MESSAGE, number=6, - message=OutputConfig, - ) + output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) - dedicated_resources = proto.Field(proto.MESSAGE, number=7, - message=machine_resources.BatchDedicatedResources, + dedicated_resources = proto.Field( + proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, ) - manual_batch_tuning_parameters = proto.Field(proto.MESSAGE, number=8, + manual_batch_tuning_parameters = proto.Field( + proto.MESSAGE, + number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) - output_info = proto.Field(proto.MESSAGE, number=9, - message=OutputInfo, - ) + output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) - state = proto.Field(proto.ENUM, number=10, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - error = proto.Field(proto.MESSAGE, number=11, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) - partial_failures = proto.RepeatedField(proto.MESSAGE, number=12, - message=status.Status, + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=12, message=status.Status, ) - resources_consumed = proto.Field(proto.MESSAGE, number=13, - message=machine_resources.ResourcesConsumed, + resources_consumed = proto.Field( + proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, ) - completion_stats = proto.Field(proto.MESSAGE, number=14, - message=gca_completion_stats.CompletionStats, + completion_stats = proto.Field( + proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=16, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=17, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=18, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=19) - encryption_spec = proto.Field(proto.MESSAGE, number=24, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1/types/completion_stats.py b/google/cloud/aiplatform_v1/types/completion_stats.py index 8a0f151024..05648d82c4 100644 --- a/google/cloud/aiplatform_v1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1/types/completion_stats.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'CompletionStats', - }, + package="google.cloud.aiplatform.v1", manifest={"CompletionStats",}, ) diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py index 176e077042..c97cba6d82 100644 --- a/google/cloud/aiplatform_v1/types/custom_job.py +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -29,14 +29,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'CustomJob', - 'CustomJobSpec', - 'WorkerPoolSpec', - 'ContainerSpec', - 'PythonPackageSpec', - 'Scheduling', + "CustomJob", + "CustomJobSpec", + "WorkerPoolSpec", + "ContainerSpec", + "PythonPackageSpec", + "Scheduling", }, ) @@ -96,38 +96,24 @@ class CustomJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, - message='CustomJobSpec', - ) + job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) - state = proto.Field(proto.ENUM, number=5, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=10, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) - encryption_spec = proto.Field(proto.MESSAGE, number=12, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, ) @@ -190,20 +176,18 @@ class CustomJobSpec(proto.Message): ``//logs/`` """ - worker_pool_specs = proto.RepeatedField(proto.MESSAGE, number=1, - message='WorkerPoolSpec', + worker_pool_specs = proto.RepeatedField( + proto.MESSAGE, number=1, message="WorkerPoolSpec", ) - scheduling = proto.Field(proto.MESSAGE, number=3, - message='Scheduling', - ) + scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) service_account = proto.Field(proto.STRING, number=4) network = proto.Field(proto.STRING, number=5) - base_output_directory = proto.Field(proto.MESSAGE, number=6, - message=io.GcsDestination, + base_output_directory = proto.Field( + proto.MESSAGE, number=6, message=io.GcsDestination, ) @@ -225,22 +209,22 @@ class WorkerPoolSpec(proto.Message): Disk spec. """ - container_spec = proto.Field(proto.MESSAGE, number=6, oneof='task', - message='ContainerSpec', + container_spec = proto.Field( + proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", ) - python_package_spec = proto.Field(proto.MESSAGE, number=7, oneof='task', - message='PythonPackageSpec', + python_package_spec = proto.Field( + proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", ) - machine_spec = proto.Field(proto.MESSAGE, number=1, - message=machine_resources.MachineSpec, + machine_spec = proto.Field( + proto.MESSAGE, number=1, message=machine_resources.MachineSpec, ) replica_count = proto.Field(proto.INT64, number=2) - disk_spec = proto.Field(proto.MESSAGE, number=5, - message=machine_resources.DiskSpec, + disk_spec = proto.Field( + proto.MESSAGE, number=5, message=machine_resources.DiskSpec, ) @@ -270,9 +254,7 @@ class ContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, - message=env_var.EnvVar, - ) + env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) class PythonPackageSpec(proto.Message): @@ -310,9 +292,7 @@ class PythonPackageSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=4) - env = proto.RepeatedField(proto.MESSAGE, number=5, - message=env_var.EnvVar, - ) + env = proto.RepeatedField(proto.MESSAGE, number=5, message=env_var.EnvVar,) class Scheduling(proto.Message): @@ -330,9 +310,7 @@ class Scheduling(proto.Message): to workers leaving and joining a job. """ - timeout = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1/types/data_item.py b/google/cloud/aiplatform_v1/types/data_item.py index d29e056d16..20ff14a0d8 100644 --- a/google/cloud/aiplatform_v1/types/data_item.py +++ b/google/cloud/aiplatform_v1/types/data_item.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'DataItem', - }, + package="google.cloud.aiplatform.v1", manifest={"DataItem",}, ) @@ -73,19 +70,13 @@ class DataItem(proto.Message): name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1/types/data_labeling_job.py b/google/cloud/aiplatform_v1/types/data_labeling_job.py index 8caca23d09..e1058737bf 100644 --- a/google/cloud/aiplatform_v1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1/types/data_labeling_job.py @@ -27,12 +27,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'DataLabelingJob', - 'ActiveLearningConfig', - 'SampleConfig', - 'TrainingConfig', + "DataLabelingJob", + "ActiveLearningConfig", + "SampleConfig", + "TrainingConfig", }, ) @@ -154,42 +154,30 @@ class DataLabelingJob(proto.Message): inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, - message=struct.Value, - ) + inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) - state = proto.Field(proto.ENUM, number=8, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, - message=money.Money, - ) + current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) - create_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=10, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=22, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) specialist_pools = proto.RepeatedField(proto.STRING, number=16) - encryption_spec = proto.Field(proto.MESSAGE, number=20, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=20, message=gca_encryption_spec.EncryptionSpec, ) - active_learning_config = proto.Field(proto.MESSAGE, number=21, - message='ActiveLearningConfig', + active_learning_config = proto.Field( + proto.MESSAGE, number=21, message="ActiveLearningConfig", ) @@ -218,18 +206,18 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field(proto.INT64, number=1, oneof='human_labeling_budget') - - max_data_item_percentage = proto.Field(proto.INT32, number=2, oneof='human_labeling_budget') - - sample_config = proto.Field(proto.MESSAGE, number=3, - message='SampleConfig', + max_data_item_count = proto.Field( + proto.INT64, number=1, oneof="human_labeling_budget" ) - training_config = proto.Field(proto.MESSAGE, number=4, - message='TrainingConfig', + max_data_item_percentage = proto.Field( + proto.INT32, number=2, oneof="human_labeling_budget" ) + sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + + training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + class SampleConfig(proto.Message): r"""Active learning data sampling config. For every active @@ -249,6 +237,7 @@ class SampleConfig(proto.Message): strategy will decide which data should be selected for human labeling in every batch. """ + class SampleStrategy(proto.Enum): r"""Sample strategy decides which subset of DataItems should be selected for human labeling in every batch. @@ -256,14 +245,16 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field(proto.INT32, number=1, oneof='initial_batch_sample_size') - - following_batch_sample_percentage = proto.Field(proto.INT32, number=3, oneof='following_batch_sample_size') + initial_batch_sample_percentage = proto.Field( + proto.INT32, number=1, oneof="initial_batch_sample_size" + ) - sample_strategy = proto.Field(proto.ENUM, number=5, - enum=SampleStrategy, + following_batch_sample_percentage = proto.Field( + proto.INT32, number=3, oneof="following_batch_sample_size" ) + sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + class TrainingConfig(proto.Message): r"""CMLE training config. For every active learning labeling diff --git a/google/cloud/aiplatform_v1/types/dataset.py b/google/cloud/aiplatform_v1/types/dataset.py index 29e205f9c4..2f75dce0d5 100644 --- a/google/cloud/aiplatform_v1/types/dataset.py +++ b/google/cloud/aiplatform_v1/types/dataset.py @@ -25,12 +25,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'Dataset', - 'ImportDataConfig', - 'ExportDataConfig', - }, + package="google.cloud.aiplatform.v1", + manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, ) @@ -98,24 +94,18 @@ class Dataset(proto.Message): metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=6) labels = proto.MapField(proto.STRING, proto.STRING, number=7) - encryption_spec = proto.Field(proto.MESSAGE, number=11, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, ) @@ -151,8 +141,8 @@ class ImportDataConfig(proto.Message): Object `__. """ - gcs_source = proto.Field(proto.MESSAGE, number=1, oneof='source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, ) data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) @@ -185,8 +175,8 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, ) annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/dataset_service.py b/google/cloud/aiplatform_v1/types/dataset_service.py index 1991dd02ec..ccc8cce600 100644 --- a/google/cloud/aiplatform_v1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1/types/dataset_service.py @@ -26,26 +26,26 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'CreateDatasetRequest', - 'CreateDatasetOperationMetadata', - 'GetDatasetRequest', - 'UpdateDatasetRequest', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'DeleteDatasetRequest', - 'ImportDataRequest', - 'ImportDataResponse', - 'ImportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportDataOperationMetadata', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'GetAnnotationSpecRequest', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", }, ) @@ -65,9 +65,7 @@ class CreateDatasetRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, - message=gca_dataset.Dataset, - ) + dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) class CreateDatasetOperationMetadata(proto.Message): @@ -79,8 +77,8 @@ class CreateDatasetOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -97,9 +95,7 @@ class GetDatasetRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateDatasetRequest(proto.Message): @@ -121,13 +117,9 @@ class UpdateDatasetRequest(proto.Message): - ``labels`` """ - dataset = proto.Field(proto.MESSAGE, number=1, - message=gca_dataset.Dataset, - ) + dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class ListDatasetsRequest(proto.Message): @@ -179,9 +171,7 @@ class ListDatasetsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -202,8 +192,8 @@ class ListDatasetsResponse(proto.Message): def raw_page(self): return self - datasets = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_dataset.Dataset, + datasets = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_dataset.Dataset, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -239,8 +229,8 @@ class ImportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField(proto.MESSAGE, number=2, - message=gca_dataset.ImportDataConfig, + import_configs = proto.RepeatedField( + proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, ) @@ -259,8 +249,8 @@ class ImportDataOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -278,8 +268,8 @@ class ExportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - export_config = proto.Field(proto.MESSAGE, number=2, - message=gca_dataset.ExportDataConfig, + export_config = proto.Field( + proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, ) @@ -309,8 +299,8 @@ class ExportDataOperationMetadata(proto.Message): the directory. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -347,9 +337,7 @@ class ListDataItemsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -370,8 +358,8 @@ class ListDataItemsResponse(proto.Message): def raw_page(self): return self - data_items = proto.RepeatedField(proto.MESSAGE, number=1, - message=data_item.DataItem, + data_items = proto.RepeatedField( + proto.MESSAGE, number=1, message=data_item.DataItem, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -392,9 +380,7 @@ class GetAnnotationSpecRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class ListAnnotationsRequest(proto.Message): @@ -429,9 +415,7 @@ class ListAnnotationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -452,8 +436,8 @@ class ListAnnotationsResponse(proto.Message): def raw_page(self): return self - annotations = proto.RepeatedField(proto.MESSAGE, number=1, - message=annotation.Annotation, + annotations = proto.RepeatedField( + proto.MESSAGE, number=1, message=annotation.Annotation, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1/types/deployed_model_ref.py index ffd0e4182d..2d53610ed5 100644 --- a/google/cloud/aiplatform_v1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1/types/deployed_model_ref.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'DeployedModelRef', - }, + package="google.cloud.aiplatform.v1", manifest={"DeployedModelRef",}, ) diff --git a/google/cloud/aiplatform_v1/types/encryption_spec.py b/google/cloud/aiplatform_v1/types/encryption_spec.py index a87a91a91e..ae908d4b72 100644 --- a/google/cloud/aiplatform_v1/types/encryption_spec.py +++ b/google/cloud/aiplatform_v1/types/encryption_spec.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'EncryptionSpec', - }, + package="google.cloud.aiplatform.v1", manifest={"EncryptionSpec",}, ) diff --git a/google/cloud/aiplatform_v1/types/endpoint.py b/google/cloud/aiplatform_v1/types/endpoint.py index cff9c6b4a7..5cbe3c1b1d 100644 --- a/google/cloud/aiplatform_v1/types/endpoint.py +++ b/google/cloud/aiplatform_v1/types/endpoint.py @@ -24,11 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'Endpoint', - 'DeployedModel', - }, + package="google.cloud.aiplatform.v1", manifest={"Endpoint", "DeployedModel",}, ) @@ -96,8 +92,8 @@ class Endpoint(proto.Message): description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField(proto.MESSAGE, number=4, - message='DeployedModel', + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=4, message="DeployedModel", ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) @@ -106,16 +102,12 @@ class Endpoint(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - encryption_spec = proto.Field(proto.MESSAGE, number=10, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, ) @@ -176,11 +168,17 @@ class DeployedModel(proto.Message): option. """ - dedicated_resources = proto.Field(proto.MESSAGE, number=7, oneof='prediction_resources', + dedicated_resources = proto.Field( + proto.MESSAGE, + number=7, + oneof="prediction_resources", message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field(proto.MESSAGE, number=8, oneof='prediction_resources', + automatic_resources = proto.Field( + proto.MESSAGE, + number=8, + oneof="prediction_resources", message=machine_resources.AutomaticResources, ) @@ -190,9 +188,7 @@ class DeployedModel(proto.Message): display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) service_account = proto.Field(proto.STRING, number=11) diff --git a/google/cloud/aiplatform_v1/types/endpoint_service.py b/google/cloud/aiplatform_v1/types/endpoint_service.py index 67b893b9aa..24e00bd486 100644 --- a/google/cloud/aiplatform_v1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1/types/endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'CreateEndpointRequest', - 'CreateEndpointOperationMetadata', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UpdateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UndeployModelOperationMetadata', + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + "DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", }, ) @@ -58,9 +58,7 @@ class CreateEndpointRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, - message=gca_endpoint.Endpoint, - ) + endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) class CreateEndpointOperationMetadata(proto.Message): @@ -72,8 +70,8 @@ class CreateEndpointOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -153,9 +151,7 @@ class ListEndpointsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -177,8 +173,8 @@ class ListEndpointsResponse(proto.Message): def raw_page(self): return self - endpoints = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_endpoint.Endpoint, + endpoints = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -197,13 +193,9 @@ class UpdateEndpointRequest(proto.Message): `FieldMask `__. """ - endpoint = proto.Field(proto.MESSAGE, number=1, - message=gca_endpoint.Endpoint, - ) + endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteEndpointRequest(proto.Message): @@ -256,8 +248,8 @@ class DeployModelRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field(proto.MESSAGE, number=2, - message=gca_endpoint.DeployedModel, + deployed_model = proto.Field( + proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -273,8 +265,8 @@ class DeployModelResponse(proto.Message): the Endpoint. """ - deployed_model = proto.Field(proto.MESSAGE, number=1, - message=gca_endpoint.DeployedModel, + deployed_model = proto.Field( + proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel, ) @@ -287,8 +279,8 @@ class DeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -337,8 +329,8 @@ class UndeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1/types/env_var.py b/google/cloud/aiplatform_v1/types/env_var.py index 8a843cd18c..f456c15808 100644 --- a/google/cloud/aiplatform_v1/types/env_var.py +++ b/google/cloud/aiplatform_v1/types/env_var.py @@ -18,12 +18,7 @@ import proto # type: ignore -__protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'EnvVar', - }, -) +__protobuf__ = proto.module(package="google.cloud.aiplatform.v1", manifest={"EnvVar",},) class EnvVar(proto.Message): diff --git a/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py index e19c94b054..63290ff9b4 100644 --- a/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py @@ -27,10 +27,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'HyperparameterTuningJob', - }, + package="google.cloud.aiplatform.v1", manifest={"HyperparameterTuningJob",}, ) @@ -109,9 +106,7 @@ class HyperparameterTuningJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, - message=study.StudySpec, - ) + study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) max_trial_count = proto.Field(proto.INT32, number=5) @@ -119,42 +114,28 @@ class HyperparameterTuningJob(proto.Message): max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field(proto.MESSAGE, number=8, - message=custom_job.CustomJobSpec, + trial_job_spec = proto.Field( + proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, - message=study.Trial, - ) + trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) - state = proto.Field(proto.ENUM, number=10, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=15, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=16) - encryption_spec = proto.Field(proto.MESSAGE, number=17, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=17, message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1/types/io.py b/google/cloud/aiplatform_v1/types/io.py index 2cf3c7b5f6..1a75ea33bc 100644 --- a/google/cloud/aiplatform_v1/types/io.py +++ b/google/cloud/aiplatform_v1/types/io.py @@ -19,13 +19,13 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'GcsSource', - 'GcsDestination', - 'BigQuerySource', - 'BigQueryDestination', - 'ContainerRegistryDestination', + "GcsSource", + "GcsDestination", + "BigQuerySource", + "BigQueryDestination", + "ContainerRegistryDestination", }, ) diff --git a/google/cloud/aiplatform_v1/types/job_service.py b/google/cloud/aiplatform_v1/types/job_service.py index edf28bd54b..3a6d844ea7 100644 --- a/google/cloud/aiplatform_v1/types/job_service.py +++ b/google/cloud/aiplatform_v1/types/job_service.py @@ -18,40 +18,44 @@ import proto # type: ignore -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'CreateCustomJobRequest', - 'GetCustomJobRequest', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'DeleteCustomJobRequest', - 'CancelCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'GetDataLabelingJobRequest', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'DeleteDataLabelingJobRequest', - 'CancelDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'GetHyperparameterTuningJobRequest', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'DeleteHyperparameterTuningJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'GetBatchPredictionJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'DeleteBatchPredictionJobRequest', - 'CancelBatchPredictionJobRequest', + "CreateCustomJobRequest", + "GetCustomJobRequest", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "DeleteCustomJobRequest", + "CancelCustomJobRequest", + "CreateDataLabelingJobRequest", + "GetDataLabelingJobRequest", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "DeleteDataLabelingJobRequest", + "CancelDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "GetHyperparameterTuningJobRequest", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "DeleteHyperparameterTuningJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "GetBatchPredictionJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "DeleteBatchPredictionJobRequest", + "CancelBatchPredictionJobRequest", }, ) @@ -71,9 +75,7 @@ class CreateCustomJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, - message=gca_custom_job.CustomJob, - ) + custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) class GetCustomJobRequest(proto.Message): @@ -136,9 +138,7 @@ class ListCustomJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListCustomJobsResponse(proto.Message): @@ -158,8 +158,8 @@ class ListCustomJobsResponse(proto.Message): def raw_page(self): return self - custom_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_custom_job.CustomJob, + custom_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -206,8 +206,8 @@ class CreateDataLabelingJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field(proto.MESSAGE, number=2, - message=gca_data_labeling_job.DataLabelingJob, + data_labeling_job = proto.Field( + proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, ) @@ -273,9 +273,7 @@ class ListDataLabelingJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -296,8 +294,8 @@ class ListDataLabelingJobsResponse(proto.Message): def raw_page(self): return self - data_labeling_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_data_labeling_job.DataLabelingJob, + data_labeling_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -348,7 +346,9 @@ class CreateHyperparameterTuningJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field(proto.MESSAGE, number=2, + hyperparameter_tuning_job = proto.Field( + proto.MESSAGE, + number=2, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -415,9 +415,7 @@ class ListHyperparameterTuningJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListHyperparameterTuningJobsResponse(proto.Message): @@ -439,7 +437,9 @@ class ListHyperparameterTuningJobsResponse(proto.Message): def raw_page(self): return self - hyperparameter_tuning_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + hyperparameter_tuning_jobs = proto.RepeatedField( + proto.MESSAGE, + number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -491,8 +491,8 @@ class CreateBatchPredictionJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field(proto.MESSAGE, number=2, - message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_job = proto.Field( + proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -558,9 +558,7 @@ class ListBatchPredictionJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListBatchPredictionJobsResponse(proto.Message): @@ -581,8 +579,8 @@ class ListBatchPredictionJobsResponse(proto.Message): def raw_page(self): return self - batch_prediction_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/job_state.py b/google/cloud/aiplatform_v1/types/job_state.py index 5ca5147c2c..40b1694f86 100644 --- a/google/cloud/aiplatform_v1/types/job_state.py +++ b/google/cloud/aiplatform_v1/types/job_state.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'JobState', - }, + package="google.cloud.aiplatform.v1", manifest={"JobState",}, ) diff --git a/google/cloud/aiplatform_v1/types/machine_resources.py b/google/cloud/aiplatform_v1/types/machine_resources.py index a5e8209b0f..f6864eb798 100644 --- a/google/cloud/aiplatform_v1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1/types/machine_resources.py @@ -22,14 +22,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'MachineSpec', - 'DedicatedResources', - 'AutomaticResources', - 'BatchDedicatedResources', - 'ResourcesConsumed', - 'DiskSpec', + "MachineSpec", + "DedicatedResources", + "AutomaticResources", + "BatchDedicatedResources", + "ResourcesConsumed", + "DiskSpec", }, ) @@ -64,8 +64,8 @@ class MachineSpec(proto.Message): machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field(proto.ENUM, number=2, - enum=gca_accelerator_type.AcceleratorType, + accelerator_type = proto.Field( + proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, ) accelerator_count = proto.Field(proto.INT32, number=3) @@ -104,9 +104,7 @@ class DedicatedResources(proto.Message): as the default value. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, - message='MachineSpec', - ) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) min_replica_count = proto.Field(proto.INT32, number=2) @@ -170,9 +168,7 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, - message='MachineSpec', - ) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) starting_replica_count = proto.Field(proto.INT32, number=2) diff --git a/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py index 07abcc8f01..7500d618a0 100644 --- a/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py +++ b/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'ManualBatchTuningParameters', - }, + package="google.cloud.aiplatform.v1", manifest={"ManualBatchTuningParameters",}, ) diff --git a/google/cloud/aiplatform_v1/types/migratable_resource.py b/google/cloud/aiplatform_v1/types/migratable_resource.py index 0b73b10a22..652a835c89 100644 --- a/google/cloud/aiplatform_v1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1/types/migratable_resource.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'MigratableResource', - }, + package="google.cloud.aiplatform.v1", manifest={"MigratableResource",}, ) @@ -55,6 +52,7 @@ class MigratableResource(proto.Message): Output only. Timestamp when this MigratableResource was last updated. """ + class MlEngineModelVersion(proto.Message): r"""Represents one model Version in ml.googleapis.com. @@ -123,6 +121,7 @@ class DataLabelingDataset(proto.Message): datalabeling.googleapis.com belongs to the data labeling Dataset. """ + class DataLabelingAnnotatedDataset(proto.Message): r"""Represents one AnnotatedDataset in datalabeling.googleapis.com. @@ -146,32 +145,34 @@ class DataLabelingAnnotatedDataset(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=4) - data_labeling_annotated_datasets = proto.RepeatedField(proto.MESSAGE, number=3, - message='MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset', + data_labeling_annotated_datasets = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", ) - ml_engine_model_version = proto.Field(proto.MESSAGE, number=1, oneof='resource', - message=MlEngineModelVersion, + ml_engine_model_version = proto.Field( + proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, ) - automl_model = proto.Field(proto.MESSAGE, number=2, oneof='resource', - message=AutomlModel, + automl_model = proto.Field( + proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, ) - automl_dataset = proto.Field(proto.MESSAGE, number=3, oneof='resource', - message=AutomlDataset, + automl_dataset = proto.Field( + proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, ) - data_labeling_dataset = proto.Field(proto.MESSAGE, number=4, oneof='resource', - message=DataLabelingDataset, + data_labeling_dataset = proto.Field( + proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, ) - last_migrate_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, + last_migrate_time = proto.Field( + proto.MESSAGE, number=5, message=timestamp.Timestamp, ) - last_update_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, + last_update_time = proto.Field( + proto.MESSAGE, number=6, message=timestamp.Timestamp, ) diff --git a/google/cloud/aiplatform_v1/types/migration_service.py b/google/cloud/aiplatform_v1/types/migration_service.py index d608620577..acd69b37b4 100644 --- a/google/cloud/aiplatform_v1/types/migration_service.py +++ b/google/cloud/aiplatform_v1/types/migration_service.py @@ -18,21 +18,23 @@ import proto # type: ignore -from google.cloud.aiplatform_v1.types import migratable_resource as gca_migratable_resource +from google.cloud.aiplatform_v1.types import ( + migratable_resource as gca_migratable_resource, +) from google.cloud.aiplatform_v1.types import operation from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'BatchMigrateResourcesRequest', - 'MigrateResourceRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceResponse', - 'BatchMigrateResourcesOperationMetadata', + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", }, ) @@ -99,8 +101,8 @@ class SearchMigratableResourcesResponse(proto.Message): def raw_page(self): return self - migratable_resources = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_migratable_resource.MigratableResource, + migratable_resources = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -123,8 +125,8 @@ class BatchMigrateResourcesRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - migrate_resource_requests = proto.RepeatedField(proto.MESSAGE, number=2, - message='MigrateResourceRequest', + migrate_resource_requests = proto.RepeatedField( + proto.MESSAGE, number=2, message="MigrateResourceRequest", ) @@ -148,6 +150,7 @@ class MigrateResourceRequest(proto.Message): datalabeling.googleapis.com to AI Platform's Dataset. """ + class MigrateMlEngineModelVersionConfig(proto.Message): r"""Config for migrating version in ml.googleapis.com to AI Platform's Model. @@ -235,6 +238,7 @@ class MigrateDataLabelingDatasetConfig(proto.Message): AnnotatedDatasets have to belong to the datalabeling Dataset. """ + class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): r"""Config for migrating AnnotatedDataset in datalabeling.googleapis.com to AI Platform's SavedQuery. @@ -253,23 +257,31 @@ class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=2) - migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField(proto.MESSAGE, number=3, - message='MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig', + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", ) - migrate_ml_engine_model_version_config = proto.Field(proto.MESSAGE, number=1, oneof='request', + migrate_ml_engine_model_version_config = proto.Field( + proto.MESSAGE, + number=1, + oneof="request", message=MigrateMlEngineModelVersionConfig, ) - migrate_automl_model_config = proto.Field(proto.MESSAGE, number=2, oneof='request', - message=MigrateAutomlModelConfig, + migrate_automl_model_config = proto.Field( + proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, ) - migrate_automl_dataset_config = proto.Field(proto.MESSAGE, number=3, oneof='request', - message=MigrateAutomlDatasetConfig, + migrate_automl_dataset_config = proto.Field( + proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, ) - migrate_data_labeling_dataset_config = proto.Field(proto.MESSAGE, number=4, oneof='request', + migrate_data_labeling_dataset_config = proto.Field( + proto.MESSAGE, + number=4, + oneof="request", message=MigrateDataLabelingDatasetConfig, ) @@ -283,8 +295,8 @@ class BatchMigrateResourcesResponse(proto.Message): Successfully migrated resources. """ - migrate_resource_responses = proto.RepeatedField(proto.MESSAGE, number=1, - message='MigrateResourceResponse', + migrate_resource_responses = proto.RepeatedField( + proto.MESSAGE, number=1, message="MigrateResourceResponse", ) @@ -302,12 +314,12 @@ class MigrateResourceResponse(proto.Message): datalabeling.googleapis.com. """ - dataset = proto.Field(proto.STRING, number=1, oneof='migrated_resource') + dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") - model = proto.Field(proto.STRING, number=2, oneof='migrated_resource') + model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") - migratable_resource = proto.Field(proto.MESSAGE, number=3, - message=gca_migratable_resource.MigratableResource, + migratable_resource = proto.Field( + proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, ) @@ -322,6 +334,7 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): Partial results that reflect the latest migration operation progress. """ + class PartialResult(proto.Message): r"""Represents a partial result in batch migration operation for one ``MigrateResourceRequest``. @@ -339,24 +352,24 @@ class PartialResult(proto.Message): [MigrateResourceRequest.migrate_resource_requests][]. """ - error = proto.Field(proto.MESSAGE, number=2, oneof='result', - message=status.Status, + error = proto.Field( + proto.MESSAGE, number=2, oneof="result", message=status.Status, ) - model = proto.Field(proto.STRING, number=3, oneof='result') + model = proto.Field(proto.STRING, number=3, oneof="result") - dataset = proto.Field(proto.STRING, number=4, oneof='result') + dataset = proto.Field(proto.STRING, number=4, oneof="result") - request = proto.Field(proto.MESSAGE, number=1, - message='MigrateResourceRequest', + request = proto.Field( + proto.MESSAGE, number=1, message="MigrateResourceRequest", ) - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - partial_results = proto.RepeatedField(proto.MESSAGE, number=2, - message=PartialResult, + partial_results = proto.RepeatedField( + proto.MESSAGE, number=2, message=PartialResult, ) diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py index b830ba86da..c2db797b98 100644 --- a/google/cloud/aiplatform_v1/types/model.py +++ b/google/cloud/aiplatform_v1/types/model.py @@ -26,13 +26,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'Model', - 'PredictSchemata', - 'ModelContainerSpec', - 'Port', - }, + package="google.cloud.aiplatform.v1", + manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, ) @@ -218,6 +213,7 @@ class Model(proto.Message): Model. If set, this Model and all sub-resources of this Model will be secured by this key. """ + class DeploymentResourcesType(proto.Enum): r"""Identifies a type of Model's prediction resources.""" DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 @@ -254,6 +250,7 @@ class ExportFormat(proto.Message): Output only. The content of this Model that may be exported. """ + class ExportableContent(proto.Enum): r"""The Model content that can be exported.""" EXPORTABLE_CONTENT_UNSPECIFIED = 0 @@ -262,8 +259,8 @@ class ExportableContent(proto.Enum): id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField(proto.ENUM, number=2, - enum='Model.ExportFormat.ExportableContent', + exportable_contents = proto.RepeatedField( + proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", ) name = proto.Field(proto.STRING, number=1) @@ -272,54 +269,44 @@ class ExportableContent(proto.Enum): description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, - message='PredictSchemata', - ) + predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - supported_export_formats = proto.RepeatedField(proto.MESSAGE, number=20, - message=ExportFormat, + supported_export_formats = proto.RepeatedField( + proto.MESSAGE, number=20, message=ExportFormat, ) training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, - message='ModelContainerSpec', - ) + container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField(proto.ENUM, number=10, - enum=DeploymentResourcesType, + supported_deployment_resources_types = proto.RepeatedField( + proto.ENUM, number=10, enum=DeploymentResourcesType, ) supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - deployed_models = proto.RepeatedField(proto.MESSAGE, number=15, - message=deployed_model_ref.DeployedModelRef, + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, ) etag = proto.Field(proto.STRING, number=16) labels = proto.MapField(proto.STRING, proto.STRING, number=17) - encryption_spec = proto.Field(proto.MESSAGE, number=24, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, ) @@ -618,13 +605,9 @@ class ModelContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, - message=env_var.EnvVar, - ) + env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) - ports = proto.RepeatedField(proto.MESSAGE, number=5, - message='Port', - ) + ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) predict_route = proto.Field(proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1/types/model_evaluation.py b/google/cloud/aiplatform_v1/types/model_evaluation.py index 08bafad024..f617f3d197 100644 --- a/google/cloud/aiplatform_v1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1/types/model_evaluation.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'ModelEvaluation', - }, + package="google.cloud.aiplatform.v1", manifest={"ModelEvaluation",}, ) @@ -66,13 +63,9 @@ class ModelEvaluation(proto.Message): metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) slice_dimensions = proto.RepeatedField(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1/types/model_evaluation_slice.py index 2b6065593c..5653c3d2b6 100644 --- a/google/cloud/aiplatform_v1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1/types/model_evaluation_slice.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'ModelEvaluationSlice', - }, + package="google.cloud.aiplatform.v1", manifest={"ModelEvaluationSlice",}, ) @@ -57,6 +54,7 @@ class ModelEvaluationSlice(proto.Message): Output only. Timestamp when this ModelEvaluationSlice was created. """ + class Slice(proto.Message): r"""Definition of a slice. @@ -81,19 +79,13 @@ class Slice(proto.Message): name = proto.Field(proto.STRING, number=1) - slice_ = proto.Field(proto.MESSAGE, number=2, - message=Slice, - ) + slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/model_service.py b/google/cloud/aiplatform_v1/types/model_service.py index e3053327c4..454e014fd5 100644 --- a/google/cloud/aiplatform_v1/types/model_service.py +++ b/google/cloud/aiplatform_v1/types/model_service.py @@ -27,25 +27,25 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'UploadModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelResponse', - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'DeleteModelRequest', - 'ExportModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'GetModelEvaluationSliceRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', + "UploadModelRequest", + "UploadModelOperationMetadata", + "UploadModelResponse", + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "DeleteModelRequest", + "ExportModelRequest", + "ExportModelOperationMetadata", + "ExportModelResponse", + "GetModelEvaluationRequest", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "GetModelEvaluationSliceRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", }, ) @@ -65,9 +65,7 @@ class UploadModelRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, - message=gca_model.Model, - ) + model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) class UploadModelOperationMetadata(proto.Message): @@ -80,8 +78,8 @@ class UploadModelOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -171,9 +169,7 @@ class ListModelsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -195,9 +191,7 @@ class ListModelsResponse(proto.Message): def raw_page(self): return self - models = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_model.Model, - ) + models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) next_page_token = proto.Field(proto.STRING, number=2) @@ -216,13 +210,9 @@ class UpdateModelRequest(proto.Message): `FieldMask `__. """ - model = proto.Field(proto.MESSAGE, number=1, - message=gca_model.Model, - ) + model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteModelRequest(proto.Message): @@ -251,6 +241,7 @@ class ExportModelRequest(proto.Message): Required. The desired output location and configuration. """ + class OutputConfig(proto.Message): r"""Output configuration for the Model export. @@ -282,19 +273,17 @@ class OutputConfig(proto.Message): export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field(proto.MESSAGE, number=3, - message=io.GcsDestination, + artifact_destination = proto.Field( + proto.MESSAGE, number=3, message=io.GcsDestination, ) - image_destination = proto.Field(proto.MESSAGE, number=4, - message=io.ContainerRegistryDestination, + image_destination = proto.Field( + proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, - message=OutputConfig, - ) + output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) class ExportModelOperationMetadata(proto.Message): @@ -309,6 +298,7 @@ class ExportModelOperationMetadata(proto.Message): Output only. Information further describing the output of this Model export. """ + class OutputInfo(proto.Message): r"""Further describes the output of the ExportModel. Supplements ``ExportModelRequest.OutputConfig``. @@ -330,13 +320,11 @@ class OutputInfo(proto.Message): image_output_uri = proto.Field(proto.STRING, number=3) - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - output_info = proto.Field(proto.MESSAGE, number=2, - message=OutputInfo, - ) + output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) class ExportModelResponse(proto.Message): @@ -391,9 +379,7 @@ class ListModelEvaluationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelEvaluationsResponse(proto.Message): @@ -414,8 +400,8 @@ class ListModelEvaluationsResponse(proto.Message): def raw_page(self): return self - model_evaluations = proto.RepeatedField(proto.MESSAGE, number=1, - message=model_evaluation.ModelEvaluation, + model_evaluations = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -470,9 +456,7 @@ class ListModelEvaluationSlicesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelEvaluationSlicesResponse(proto.Message): @@ -493,8 +477,8 @@ class ListModelEvaluationSlicesResponse(proto.Message): def raw_page(self): return self - model_evaluation_slices = proto.RepeatedField(proto.MESSAGE, number=1, - message=model_evaluation_slice.ModelEvaluationSlice, + model_evaluation_slices = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/operation.py b/google/cloud/aiplatform_v1/types/operation.py index 2f8211a6ad..fe24030e79 100644 --- a/google/cloud/aiplatform_v1/types/operation.py +++ b/google/cloud/aiplatform_v1/types/operation.py @@ -23,11 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'GenericOperationMetadata', - 'DeleteOperationMetadata', - }, + package="google.cloud.aiplatform.v1", + manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, ) @@ -51,17 +48,13 @@ class GenericOperationMetadata(proto.Message): finish time. """ - partial_failures = proto.RepeatedField(proto.MESSAGE, number=1, - message=status.Status, + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=1, message=status.Status, ) - create_time = proto.Field(proto.MESSAGE, number=2, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) class DeleteOperationMetadata(proto.Message): @@ -72,8 +65,8 @@ class DeleteOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message='GenericOperationMetadata', + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message="GenericOperationMetadata", ) diff --git a/google/cloud/aiplatform_v1/types/pipeline_service.py b/google/cloud/aiplatform_v1/types/pipeline_service.py index e757607527..b2c6d5bbe3 100644 --- a/google/cloud/aiplatform_v1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1/types/pipeline_service.py @@ -23,14 +23,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'CreateTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'DeleteTrainingPipelineRequest', - 'CancelTrainingPipelineRequest', + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", }, ) @@ -50,8 +50,8 @@ class CreateTrainingPipelineRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field(proto.MESSAGE, number=2, - message=gca_training_pipeline.TrainingPipeline, + training_pipeline = proto.Field( + proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, ) @@ -114,9 +114,7 @@ class ListTrainingPipelinesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListTrainingPipelinesResponse(proto.Message): @@ -137,8 +135,8 @@ class ListTrainingPipelinesResponse(proto.Message): def raw_page(self): return self - training_pipelines = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_training_pipeline.TrainingPipeline, + training_pipelines = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/pipeline_state.py b/google/cloud/aiplatform_v1/types/pipeline_state.py index 6a00f05fef..f6a885ae42 100644 --- a/google/cloud/aiplatform_v1/types/pipeline_state.py +++ b/google/cloud/aiplatform_v1/types/pipeline_state.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'PipelineState', - }, + package="google.cloud.aiplatform.v1", manifest={"PipelineState",}, ) diff --git a/google/cloud/aiplatform_v1/types/prediction_service.py b/google/cloud/aiplatform_v1/types/prediction_service.py index c7d39c373b..21a01372f4 100644 --- a/google/cloud/aiplatform_v1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1/types/prediction_service.py @@ -22,11 +22,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'PredictRequest', - 'PredictResponse', - }, + package="google.cloud.aiplatform.v1", + manifest={"PredictRequest", "PredictResponse",}, ) @@ -61,13 +58,9 @@ class PredictRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, - message=struct.Value, - ) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) class PredictResponse(proto.Message): @@ -87,9 +80,7 @@ class PredictResponse(proto.Message): served this prediction. """ - predictions = proto.RepeatedField(proto.MESSAGE, number=1, - message=struct.Value, - ) + predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1/types/specialist_pool.py b/google/cloud/aiplatform_v1/types/specialist_pool.py index b57aa89666..6265316bd5 100644 --- a/google/cloud/aiplatform_v1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1/types/specialist_pool.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'SpecialistPool', - }, + package="google.cloud.aiplatform.v1", manifest={"SpecialistPool",}, ) diff --git a/google/cloud/aiplatform_v1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1/types/specialist_pool_service.py index 669756640f..69e49bb355 100644 --- a/google/cloud/aiplatform_v1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1/types/specialist_pool_service.py @@ -24,16 +24,16 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'CreateSpecialistPoolRequest', - 'CreateSpecialistPoolOperationMetadata', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'DeleteSpecialistPoolRequest', - 'UpdateSpecialistPoolRequest', - 'UpdateSpecialistPoolOperationMetadata', + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", }, ) @@ -53,8 +53,8 @@ class CreateSpecialistPoolRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field(proto.MESSAGE, number=2, - message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field( + proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, ) @@ -67,8 +67,8 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -114,9 +114,7 @@ class ListSpecialistPoolsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) class ListSpecialistPoolsResponse(proto.Message): @@ -135,8 +133,8 @@ class ListSpecialistPoolsResponse(proto.Message): def raw_page(self): return self - specialist_pools = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_specialist_pool.SpecialistPool, + specialist_pools = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -176,13 +174,11 @@ class UpdateSpecialistPoolRequest(proto.Message): resource. """ - specialist_pool = proto.Field(proto.MESSAGE, number=1, - message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateSpecialistPoolOperationMetadata(proto.Message): @@ -201,8 +197,8 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field(proto.MESSAGE, number=2, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1/types/study.py b/google/cloud/aiplatform_v1/types/study.py index 0254866d5b..99a688f045 100644 --- a/google/cloud/aiplatform_v1/types/study.py +++ b/google/cloud/aiplatform_v1/types/study.py @@ -23,12 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'Trial', - 'StudySpec', - 'Measurement', - }, + package="google.cloud.aiplatform.v1", + manifest={"Trial", "StudySpec", "Measurement",}, ) @@ -58,6 +54,7 @@ class Trial(proto.Message): Trial. It's set for a HyperparameterTuningJob's Trial. """ + class State(proto.Enum): r"""Describes a Trial state.""" STATE_UNSPECIFIED = 0 @@ -85,31 +82,19 @@ class Parameter(proto.Message): parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, - message=struct.Value, - ) + value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, - enum=State, - ) + state = proto.Field(proto.ENUM, number=3, enum=State,) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, - message=Parameter, - ) + parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) - final_measurement = proto.Field(proto.MESSAGE, number=5, - message='Measurement', - ) + final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) custom_job = proto.Field(proto.STRING, number=11) @@ -133,6 +118,7 @@ class StudySpec(proto.Message): Describe which measurement selection type will be used """ + class Algorithm(proto.Enum): r"""The available search algorithms for the Study.""" ALGORITHM_UNSPECIFIED = 0 @@ -178,6 +164,7 @@ class MetricSpec(proto.Message): Required. The optimization goal of the metric. """ + class GoalType(proto.Enum): r"""The available types of optimization goals.""" GOAL_TYPE_UNSPECIFIED = 0 @@ -186,9 +173,7 @@ class GoalType(proto.Enum): metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, - enum='StudySpec.MetricSpec.GoalType', - ) + goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. @@ -216,6 +201,7 @@ class ParameterSpec(proto.Message): If two items in conditional_parameter_specs have the same name, they must have disjoint parent_value_condition. """ + class ScaleType(proto.Enum): r"""The type of scaling that should be applied to this parameter.""" SCALE_TYPE_UNSPECIFIED = 0 @@ -298,6 +284,7 @@ class ConditionalParameterSpec(proto.Message): Required. The spec for a conditional parameter. """ + class DiscreteValueCondition(proto.Message): r"""Represents the spec to match discrete values from parent parameter. @@ -339,66 +326,81 @@ class CategoricalValueCondition(proto.Message): values = proto.RepeatedField(proto.STRING, number=1) - parent_discrete_values = proto.Field(proto.MESSAGE, number=2, oneof='parent_value_condition', - message='StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition', + parent_discrete_values = proto.Field( + proto.MESSAGE, + number=2, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", ) - parent_int_values = proto.Field(proto.MESSAGE, number=3, oneof='parent_value_condition', - message='StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition', + parent_int_values = proto.Field( + proto.MESSAGE, + number=3, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", ) - parent_categorical_values = proto.Field(proto.MESSAGE, number=4, oneof='parent_value_condition', - message='StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition', + parent_categorical_values = proto.Field( + proto.MESSAGE, + number=4, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", ) - parameter_spec = proto.Field(proto.MESSAGE, number=1, - message='StudySpec.ParameterSpec', + parameter_spec = proto.Field( + proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", ) - double_value_spec = proto.Field(proto.MESSAGE, number=2, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.DoubleValueSpec', + double_value_spec = proto.Field( + proto.MESSAGE, + number=2, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DoubleValueSpec", ) - integer_value_spec = proto.Field(proto.MESSAGE, number=3, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.IntegerValueSpec', + integer_value_spec = proto.Field( + proto.MESSAGE, + number=3, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.IntegerValueSpec", ) - categorical_value_spec = proto.Field(proto.MESSAGE, number=4, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.CategoricalValueSpec', + categorical_value_spec = proto.Field( + proto.MESSAGE, + number=4, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.CategoricalValueSpec", ) - discrete_value_spec = proto.Field(proto.MESSAGE, number=5, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.DiscreteValueSpec', + discrete_value_spec = proto.Field( + proto.MESSAGE, + number=5, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DiscreteValueSpec", ) parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field(proto.ENUM, number=6, - enum='StudySpec.ParameterSpec.ScaleType', + scale_type = proto.Field( + proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", ) - conditional_parameter_specs = proto.RepeatedField(proto.MESSAGE, number=10, - message='StudySpec.ParameterSpec.ConditionalParameterSpec', + conditional_parameter_specs = proto.RepeatedField( + proto.MESSAGE, + number=10, + message="StudySpec.ParameterSpec.ConditionalParameterSpec", ) - metrics = proto.RepeatedField(proto.MESSAGE, number=1, - message=MetricSpec, - ) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, - message=ParameterSpec, - ) + parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) - algorithm = proto.Field(proto.ENUM, number=3, - enum=Algorithm, - ) + algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) - observation_noise = proto.Field(proto.ENUM, number=6, - enum=ObservationNoise, - ) + observation_noise = proto.Field(proto.ENUM, number=6, enum=ObservationNoise,) - measurement_selection_type = proto.Field(proto.ENUM, number=7, - enum=MeasurementSelectionType, + measurement_selection_type = proto.Field( + proto.ENUM, number=7, enum=MeasurementSelectionType, ) @@ -417,6 +419,7 @@ class Measurement(proto.Message): evaluating the objective functions using suggested Parameter values. """ + class Metric(proto.Message): r"""A message representing a metric in the measurement. @@ -435,9 +438,7 @@ class Metric(proto.Message): step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, - message=Metric, - ) + metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/training_pipeline.py b/google/cloud/aiplatform_v1/types/training_pipeline.py index b0135a926b..9a41f231a5 100644 --- a/google/cloud/aiplatform_v1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1/types/training_pipeline.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', + package="google.cloud.aiplatform.v1", manifest={ - 'TrainingPipeline', - 'InputDataConfig', - 'FractionSplit', - 'FilterSplit', - 'PredefinedSplit', - 'TimestampSplit', + "TrainingPipeline", + "InputDataConfig", + "FractionSplit", + "FilterSplit", + "PredefinedSplit", + "TimestampSplit", }, ) @@ -154,52 +154,32 @@ class TrainingPipeline(proto.Message): display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, - message='InputDataConfig', - ) + input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, - message=struct.Value, - ) + training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - model_to_upload = proto.Field(proto.MESSAGE, number=7, - message=model.Model, - ) + model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) - state = proto.Field(proto.ENUM, number=9, - enum=pipeline_state.PipelineState, - ) + state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) - error = proto.Field(proto.MESSAGE, number=10, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=15) - encryption_spec = proto.Field(proto.MESSAGE, number=18, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=18, message=gca_encryption_spec.EncryptionSpec, ) @@ -323,28 +303,28 @@ class InputDataConfig(proto.Message): ``annotation_schema_uri``. """ - fraction_split = proto.Field(proto.MESSAGE, number=2, oneof='split', - message='FractionSplit', + fraction_split = proto.Field( + proto.MESSAGE, number=2, oneof="split", message="FractionSplit", ) - filter_split = proto.Field(proto.MESSAGE, number=3, oneof='split', - message='FilterSplit', + filter_split = proto.Field( + proto.MESSAGE, number=3, oneof="split", message="FilterSplit", ) - predefined_split = proto.Field(proto.MESSAGE, number=4, oneof='split', - message='PredefinedSplit', + predefined_split = proto.Field( + proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", ) - timestamp_split = proto.Field(proto.MESSAGE, number=5, oneof='split', - message='TimestampSplit', + timestamp_split = proto.Field( + proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", ) - gcs_destination = proto.Field(proto.MESSAGE, number=8, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, ) - bigquery_destination = proto.Field(proto.MESSAGE, number=10, oneof='destination', - message=io.BigQueryDestination, + bigquery_destination = proto.Field( + proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, ) dataset_id = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1/types/user_action_reference.py b/google/cloud/aiplatform_v1/types/user_action_reference.py index 89d799178a..da59ac6ac6 100644 --- a/google/cloud/aiplatform_v1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1/types/user_action_reference.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1', - manifest={ - 'UserActionReference', - }, + package="google.cloud.aiplatform.v1", manifest={"UserActionReference",}, ) @@ -47,9 +44,9 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". """ - operation = proto.Field(proto.STRING, number=1, oneof='reference') + operation = proto.Field(proto.STRING, number=1, oneof="reference") - data_labeling_job = proto.Field(proto.STRING, number=2, oneof='reference') + data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") method = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 0dbcbec2d6..cb2f3afb74 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -198,11 +198,19 @@ from .types.model import ModelContainerSpec from .types.model import Port from .types.model import PredictSchemata -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringBigQueryTable +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringBigQueryTable, +) from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringJob -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringObjectiveConfig -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringObjectiveType -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringScheduleConfig +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringObjectiveConfig, +) +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringObjectiveType, +) +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringScheduleConfig, +) from .types.model_deployment_monitoring_job import ModelMonitoringStatsAnomalies from .types.model_evaluation import ModelEvaluation from .types.model_evaluation_slice import ModelEvaluationSlice @@ -285,271 +293,271 @@ __all__ = ( - 'AcceleratorType', - 'ActiveLearningConfig', - 'AddContextArtifactsAndExecutionsRequest', - 'AddContextArtifactsAndExecutionsResponse', - 'AddContextChildrenRequest', - 'AddContextChildrenResponse', - 'AddExecutionEventsRequest', - 'AddExecutionEventsResponse', - 'AddTrialMeasurementRequest', - 'Annotation', - 'AnnotationSpec', - 'Artifact', - 'Attribution', - 'AutomaticResources', - 'AutoscalingMetricSpec', - 'BatchDedicatedResources', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'BatchPredictionJob', - 'BigQueryDestination', - 'BigQuerySource', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CancelTrainingPipelineRequest', - 'CheckTrialEarlyStoppingStateMetatdata', - 'CheckTrialEarlyStoppingStateRequest', - 'CheckTrialEarlyStoppingStateResponse', - 'CompleteTrialRequest', - 'CompletionStats', - 'ContainerRegistryDestination', - 'ContainerSpec', - 'Context', - 'CreateArtifactRequest', - 'CreateBatchPredictionJobRequest', - 'CreateContextRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'CreateExecutionRequest', - 'CreateHyperparameterTuningJobRequest', - 'CreateMetadataSchemaRequest', - 'CreateMetadataStoreOperationMetadata', - 'CreateMetadataStoreRequest', - 'CreateModelDeploymentMonitoringJobRequest', - 'CreateSpecialistPoolOperationMetadata', - 'CreateSpecialistPoolRequest', - 'CreateStudyRequest', - 'CreateTrainingPipelineRequest', - 'CreateTrialRequest', - 'CustomJob', - 'CustomJobSpec', - 'DataItem', - 'DataLabelingJob', - 'Dataset', - 'DatasetServiceClient', - 'DedicatedResources', - 'DeleteBatchPredictionJobRequest', - 'DeleteContextRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteDatasetRequest', - 'DeleteEndpointRequest', - 'DeleteHyperparameterTuningJobRequest', - 'DeleteMetadataStoreOperationMetadata', - 'DeleteMetadataStoreRequest', - 'DeleteModelDeploymentMonitoringJobRequest', - 'DeleteModelRequest', - 'DeleteOperationMetadata', - 'DeleteSpecialistPoolRequest', - 'DeleteStudyRequest', - 'DeleteTrainingPipelineRequest', - 'DeleteTrialRequest', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployedModel', - 'DeployedModelRef', - 'DiskSpec', - 'EncryptionSpec', - 'Endpoint', - 'EndpointServiceClient', - 'EnvVar', - 'Event', - 'Execution', - 'ExplainRequest', - 'ExplainResponse', - 'Explanation', - 'ExplanationMetadata', - 'ExplanationMetadataOverride', - 'ExplanationParameters', - 'ExplanationSpec', - 'ExplanationSpecOverride', - 'ExportDataConfig', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'FeatureNoiseSigma', - 'FeatureStatsAnomaly', - 'FilterSplit', - 'FractionSplit', - 'GcsDestination', - 'GcsSource', - 'GenericOperationMetadata', - 'GetAnnotationSpecRequest', - 'GetArtifactRequest', - 'GetBatchPredictionJobRequest', - 'GetContextRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetDatasetRequest', - 'GetEndpointRequest', - 'GetExecutionRequest', - 'GetHyperparameterTuningJobRequest', - 'GetMetadataSchemaRequest', - 'GetMetadataStoreRequest', - 'GetModelDeploymentMonitoringJobRequest', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'GetSpecialistPoolRequest', - 'GetStudyRequest', - 'GetTrainingPipelineRequest', - 'GetTrialRequest', - 'HyperparameterTuningJob', - 'ImportDataConfig', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'InputDataConfig', - 'IntegratedGradientsAttribution', - 'JobServiceClient', - 'JobState', - 'LineageSubgraph', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListArtifactsRequest', - 'ListArtifactsResponse', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListContextsRequest', - 'ListContextsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'ListExecutionsRequest', - 'ListExecutionsResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'ListMetadataSchemasRequest', - 'ListMetadataSchemasResponse', - 'ListMetadataStoresRequest', - 'ListMetadataStoresResponse', - 'ListModelDeploymentMonitoringJobsRequest', - 'ListModelDeploymentMonitoringJobsResponse', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'ListOptimalTrialsRequest', - 'ListOptimalTrialsResponse', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'ListStudiesRequest', - 'ListStudiesResponse', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'ListTrialsRequest', - 'ListTrialsResponse', - 'LookupStudyRequest', - 'MachineSpec', - 'ManualBatchTuningParameters', - 'Measurement', - 'MetadataSchema', - 'MetadataStore', - 'MigratableResource', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'MigrationServiceClient', - 'Model', - 'ModelContainerSpec', - 'ModelDeploymentMonitoringBigQueryTable', - 'ModelDeploymentMonitoringJob', - 'ModelDeploymentMonitoringObjectiveConfig', - 'ModelDeploymentMonitoringObjectiveType', - 'ModelDeploymentMonitoringScheduleConfig', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'ModelExplanation', - 'ModelMonitoringAlertConfig', - 'ModelMonitoringObjectiveConfig', - 'ModelMonitoringStatsAnomalies', - 'ModelServiceClient', - 'PauseModelDeploymentMonitoringJobRequest', - 'PipelineServiceClient', - 'PipelineState', - 'Port', - 'PredefinedSplit', - 'PredictRequest', - 'PredictResponse', - 'PredictSchemata', - 'PredictionServiceClient', - 'PythonPackageSpec', - 'QueryContextLineageSubgraphRequest', - 'QueryExecutionInputsAndOutputsRequest', - 'ResourcesConsumed', - 'ResumeModelDeploymentMonitoringJobRequest', - 'SampleConfig', - 'SampledShapleyAttribution', - 'SamplingStrategy', - 'Scheduling', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', - 'SmoothGradConfig', - 'SpecialistPool', - 'SpecialistPoolServiceClient', - 'StopTrialRequest', - 'Study', - 'StudySpec', - 'SuggestTrialsMetadata', - 'SuggestTrialsRequest', - 'SuggestTrialsResponse', - 'ThresholdConfig', - 'TimestampSplit', - 'TrainingConfig', - 'TrainingPipeline', - 'Trial', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateArtifactRequest', - 'UpdateContextRequest', - 'UpdateDatasetRequest', - 'UpdateEndpointRequest', - 'UpdateExecutionRequest', - 'UpdateModelDeploymentMonitoringJobOperationMetadata', - 'UpdateModelDeploymentMonitoringJobRequest', - 'UpdateModelRequest', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 'UploadModelResponse', - 'UserActionReference', - 'VizierServiceClient', - 'WorkerPoolSpec', - 'XraiAttribution', -'MetadataServiceClient', + "AcceleratorType", + "ActiveLearningConfig", + "AddContextArtifactsAndExecutionsRequest", + "AddContextArtifactsAndExecutionsResponse", + "AddContextChildrenRequest", + "AddContextChildrenResponse", + "AddExecutionEventsRequest", + "AddExecutionEventsResponse", + "AddTrialMeasurementRequest", + "Annotation", + "AnnotationSpec", + "Artifact", + "Attribution", + "AutomaticResources", + "AutoscalingMetricSpec", + "BatchDedicatedResources", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "BatchPredictionJob", + "BigQueryDestination", + "BigQuerySource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CancelTrainingPipelineRequest", + "CheckTrialEarlyStoppingStateMetatdata", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CompleteTrialRequest", + "CompletionStats", + "ContainerRegistryDestination", + "ContainerSpec", + "Context", + "CreateArtifactRequest", + "CreateBatchPredictionJobRequest", + "CreateContextRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "CreateExecutionRequest", + "CreateHyperparameterTuningJobRequest", + "CreateMetadataSchemaRequest", + "CreateMetadataStoreOperationMetadata", + "CreateMetadataStoreRequest", + "CreateModelDeploymentMonitoringJobRequest", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "CreateStudyRequest", + "CreateTrainingPipelineRequest", + "CreateTrialRequest", + "CustomJob", + "CustomJobSpec", + "DataItem", + "DataLabelingJob", + "Dataset", + "DatasetServiceClient", + "DedicatedResources", + "DeleteBatchPredictionJobRequest", + "DeleteContextRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteDatasetRequest", + "DeleteEndpointRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteMetadataStoreOperationMetadata", + "DeleteMetadataStoreRequest", + "DeleteModelDeploymentMonitoringJobRequest", + "DeleteModelRequest", + "DeleteOperationMetadata", + "DeleteSpecialistPoolRequest", + "DeleteStudyRequest", + "DeleteTrainingPipelineRequest", + "DeleteTrialRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "DeployedModel", + "DeployedModelRef", + "DiskSpec", + "EncryptionSpec", + "Endpoint", + "EndpointServiceClient", + "EnvVar", + "Event", + "Execution", + "ExplainRequest", + "ExplainResponse", + "Explanation", + "ExplanationMetadata", + "ExplanationMetadataOverride", + "ExplanationParameters", + "ExplanationSpec", + "ExplanationSpecOverride", + "ExportDataConfig", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "FeatureNoiseSigma", + "FeatureStatsAnomaly", + "FilterSplit", + "FractionSplit", + "GcsDestination", + "GcsSource", + "GenericOperationMetadata", + "GetAnnotationSpecRequest", + "GetArtifactRequest", + "GetBatchPredictionJobRequest", + "GetContextRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetDatasetRequest", + "GetEndpointRequest", + "GetExecutionRequest", + "GetHyperparameterTuningJobRequest", + "GetMetadataSchemaRequest", + "GetMetadataStoreRequest", + "GetModelDeploymentMonitoringJobRequest", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "GetSpecialistPoolRequest", + "GetStudyRequest", + "GetTrainingPipelineRequest", + "GetTrialRequest", + "HyperparameterTuningJob", + "ImportDataConfig", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "InputDataConfig", + "IntegratedGradientsAttribution", + "JobServiceClient", + "JobState", + "LineageSubgraph", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListArtifactsRequest", + "ListArtifactsResponse", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListContextsRequest", + "ListContextsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "ListEndpointsRequest", + "ListEndpointsResponse", + "ListExecutionsRequest", + "ListExecutionsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListMetadataSchemasRequest", + "ListMetadataSchemasResponse", + "ListMetadataStoresRequest", + "ListMetadataStoresResponse", + "ListModelDeploymentMonitoringJobsRequest", + "ListModelDeploymentMonitoringJobsResponse", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "ListStudiesRequest", + "ListStudiesResponse", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "ListTrialsRequest", + "ListTrialsResponse", + "LookupStudyRequest", + "MachineSpec", + "ManualBatchTuningParameters", + "Measurement", + "MetadataSchema", + "MetadataStore", + "MigratableResource", + "MigrateResourceRequest", + "MigrateResourceResponse", + "MigrationServiceClient", + "Model", + "ModelContainerSpec", + "ModelDeploymentMonitoringBigQueryTable", + "ModelDeploymentMonitoringJob", + "ModelDeploymentMonitoringObjectiveConfig", + "ModelDeploymentMonitoringObjectiveType", + "ModelDeploymentMonitoringScheduleConfig", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelExplanation", + "ModelMonitoringAlertConfig", + "ModelMonitoringObjectiveConfig", + "ModelMonitoringStatsAnomalies", + "ModelServiceClient", + "PauseModelDeploymentMonitoringJobRequest", + "PipelineServiceClient", + "PipelineState", + "Port", + "PredefinedSplit", + "PredictRequest", + "PredictResponse", + "PredictSchemata", + "PredictionServiceClient", + "PythonPackageSpec", + "QueryContextLineageSubgraphRequest", + "QueryExecutionInputsAndOutputsRequest", + "ResourcesConsumed", + "ResumeModelDeploymentMonitoringJobRequest", + "SampleConfig", + "SampledShapleyAttribution", + "SamplingStrategy", + "Scheduling", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "SearchModelDeploymentMonitoringStatsAnomaliesRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesResponse", + "SmoothGradConfig", + "SpecialistPool", + "SpecialistPoolServiceClient", + "StopTrialRequest", + "Study", + "StudySpec", + "SuggestTrialsMetadata", + "SuggestTrialsRequest", + "SuggestTrialsResponse", + "ThresholdConfig", + "TimestampSplit", + "TrainingConfig", + "TrainingPipeline", + "Trial", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateArtifactRequest", + "UpdateContextRequest", + "UpdateDatasetRequest", + "UpdateEndpointRequest", + "UpdateExecutionRequest", + "UpdateModelDeploymentMonitoringJobOperationMetadata", + "UpdateModelDeploymentMonitoringJobRequest", + "UpdateModelRequest", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "UserActionReference", + "VizierServiceClient", + "WorkerPoolSpec", + "XraiAttribution", + "MetadataServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py index 9d1f004f6a..597f654cb9 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import DatasetServiceAsyncClient __all__ = ( - 'DatasetServiceClient', - 'DatasetServiceAsyncClient', + "DatasetServiceClient", + "DatasetServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 2eb9ce6f7a..d91df4b644 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,26 +60,42 @@ class DatasetServiceAsyncClient: annotation_path = staticmethod(DatasetServiceClient.annotation_path) parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) - parse_annotation_spec_path = staticmethod(DatasetServiceClient.parse_annotation_spec_path) + parse_annotation_spec_path = staticmethod( + DatasetServiceClient.parse_annotation_spec_path + ) data_item_path = staticmethod(DatasetServiceClient.data_item_path) parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) dataset_path = staticmethod(DatasetServiceClient.dataset_path) parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) - common_billing_account_path = staticmethod(DatasetServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(DatasetServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + DatasetServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + DatasetServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(DatasetServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + DatasetServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(DatasetServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(DatasetServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + DatasetServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DatasetServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(DatasetServiceClient.common_project_path) - parse_common_project_path = staticmethod(DatasetServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + DatasetServiceClient.parse_common_project_path + ) common_location_path = staticmethod(DatasetServiceClient.common_location_path) - parse_common_location_path = staticmethod(DatasetServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + DatasetServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -122,14 +138,18 @@ def transport(self) -> DatasetServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient)) + get_transport_class = functools.partial( + type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -168,18 +188,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a Dataset. Args: @@ -220,8 +240,10 @@ async def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.CreateDatasetRequest(request) @@ -244,18 +266,11 @@ async def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -268,14 +283,15 @@ async def create_dataset(self, # Done; return the response. return response - async def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + async def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -307,8 +323,10 @@ async def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.GetDatasetRequest(request) @@ -329,31 +347,25 @@ async def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + async def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -398,8 +410,10 @@ async def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.UpdateDatasetRequest(request) @@ -422,30 +436,26 @@ async def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsAsyncPager: + async def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: r"""Lists Datasets in a Location. Args: @@ -480,8 +490,10 @@ async def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListDatasetsRequest(request) @@ -502,39 +514,30 @@ async def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDatasetsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Dataset. Args: @@ -580,8 +583,10 @@ async def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.DeleteDatasetRequest(request) @@ -602,18 +607,11 @@ async def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -626,15 +624,16 @@ async def delete_dataset(self, # Done; return the response. return response - async def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Imports data into a Dataset. Args: @@ -678,8 +677,10 @@ async def import_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ImportDataRequest(request) @@ -703,18 +704,11 @@ async def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -727,15 +721,16 @@ async def import_data(self, # Done; return the response. return response - async def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports data from a Dataset. Args: @@ -778,8 +773,10 @@ async def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ExportDataRequest(request) @@ -802,18 +799,11 @@ async def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -826,14 +816,15 @@ async def export_data(self, # Done; return the response. return response - async def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsAsyncPager: + async def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: r"""Lists DataItems in a Dataset. Args: @@ -869,8 +860,10 @@ async def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListDataItemsRequest(request) @@ -891,39 +884,30 @@ async def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataItemsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + async def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -956,8 +940,10 @@ async def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.GetAnnotationSpecRequest(request) @@ -978,30 +964,24 @@ async def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsAsyncPager: + async def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1037,8 +1017,10 @@ async def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListAnnotationsRequest(request) @@ -1059,47 +1041,30 @@ async def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListAnnotationsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceAsyncClient', -) +__all__ = ("DatasetServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 9d139e6b64..37aecfc5e5 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,14 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry['grpc'] = DatasetServiceGrpcTransport - _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[DatasetServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +118,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +153,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DatasetServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,110 +169,149 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str,str]: + def parse_annotation_path(path: str) -> Dict[str, str]: """Parse a annotation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str,str]: + def parse_annotation_spec_path(path: str) -> Dict[str, str]: """Parse a annotation_spec path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str,str]: + def parse_data_item_path(path: str) -> Dict[str, str]: """Parse a data_item path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -316,7 +355,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -326,7 +367,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -338,7 +381,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -350,8 +395,10 @@ def __init__(self, *, if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -370,15 +417,16 @@ def __init__(self, *, client_info=client_info, ) - def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a Dataset. Args: @@ -419,8 +467,10 @@ def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -444,18 +494,11 @@ def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -468,14 +511,15 @@ def create_dataset(self, # Done; return the response. return response - def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -507,8 +551,10 @@ def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -530,31 +576,25 @@ def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -599,8 +639,10 @@ def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -624,30 +666,26 @@ def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -682,8 +720,10 @@ def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -705,39 +745,30 @@ def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Dataset. Args: @@ -783,8 +814,10 @@ def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -806,18 +839,11 @@ def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -830,15 +856,16 @@ def delete_dataset(self, # Done; return the response. return response - def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Imports data into a Dataset. Args: @@ -882,8 +909,10 @@ def import_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -907,18 +936,11 @@ def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -931,15 +953,16 @@ def import_data(self, # Done; return the response. return response - def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports data from a Dataset. Args: @@ -982,8 +1005,10 @@ def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -1007,18 +1032,11 @@ def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1031,14 +1049,15 @@ def export_data(self, # Done; return the response. return response - def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -1074,8 +1093,10 @@ def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1097,39 +1118,30 @@ def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1162,8 +1174,10 @@ def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1185,30 +1199,24 @@ def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1244,8 +1252,10 @@ def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1267,47 +1277,30 @@ def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceClient', -) +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index aa9114bc5f..63560b32ba 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item @@ -40,12 +49,15 @@ class ListDatasetsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListDatasetsResponse], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -79,7 +91,7 @@ def __iter__(self) -> Iterable[dataset.Dataset]: yield from page.datasets def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDatasetsAsyncPager: @@ -99,12 +111,15 @@ class ListDatasetsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -142,7 +157,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataItemsPager: @@ -162,12 +177,15 @@ class ListDataItemsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListDataItemsResponse], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +219,7 @@ def __iter__(self) -> Iterable[data_item.DataItem]: yield from page.data_items def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataItemsAsyncPager: @@ -221,12 +239,15 @@ class ListDataItemsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -264,7 +285,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListAnnotationsPager: @@ -284,12 +305,15 @@ class ListAnnotationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListAnnotationsResponse], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +347,7 @@ def __iter__(self) -> Iterable[annotation.Annotation]: yield from page.annotations def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListAnnotationsAsyncPager: @@ -343,12 +367,15 @@ class ListAnnotationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -386,4 +413,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index 5f02a0f0d9..a4461d2ced 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] -_transport_registry['grpc'] = DatasetServiceGrpcTransport -_transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = DatasetServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport __all__ = ( - 'DatasetServiceTransport', - 'DatasetServiceGrpcTransport', - 'DatasetServiceGrpcAsyncIOTransport', + "DatasetServiceTransport", + "DatasetServiceGrpcTransport", + "DatasetServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 74909b2980..75dc66a554 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,8 +81,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -91,17 +91,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -110,56 +112,35 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, - default_timeout=5.0, - client_info=client_info, + self.create_dataset, default_timeout=5.0, client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, - default_timeout=5.0, - client_info=client_info, + self.get_dataset, default_timeout=5.0, client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, - default_timeout=5.0, - client_info=client_info, + self.update_dataset, default_timeout=5.0, client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, - default_timeout=5.0, - client_info=client_info, + self.list_datasets, default_timeout=5.0, client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, - default_timeout=5.0, - client_info=client_info, + self.delete_dataset, default_timeout=5.0, client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( - self.import_data, - default_timeout=5.0, - client_info=client_info, + self.import_data, default_timeout=5.0, client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( - self.export_data, - default_timeout=5.0, - client_info=client_info, + self.export_data, default_timeout=5.0, client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, - default_timeout=5.0, - client_info=client_info, + self.list_data_items, default_timeout=5.0, client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, - default_timeout=5.0, - client_info=client_info, + self.get_annotation_spec, default_timeout=5.0, client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, - default_timeout=5.0, - client_info=client_info, + self.list_annotations, default_timeout=5.0, client_info=client_info, ), - } @property @@ -168,96 +149,106 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_dataset(self) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_dataset( + self, + ) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_dataset(self) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[ - dataset.Dataset, - typing.Awaitable[dataset.Dataset] - ]]: + def get_dataset( + self, + ) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], + ]: raise NotImplementedError() @property - def update_dataset(self) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[ - gca_dataset.Dataset, - typing.Awaitable[gca_dataset.Dataset] - ]]: + def update_dataset( + self, + ) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], + ]: raise NotImplementedError() @property - def list_datasets(self) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse] - ]]: + def list_datasets( + self, + ) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse], + ], + ]: raise NotImplementedError() @property - def delete_dataset(self) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_dataset( + self, + ) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def import_data(self) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def import_data( + self, + ) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_data(self) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_data( + self, + ) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def list_data_items(self) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse] - ]]: + def list_data_items( + self, + ) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse], + ], + ]: raise NotImplementedError() @property - def get_annotation_spec(self) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec] - ]]: + def get_annotation_spec( + self, + ) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec], + ], + ]: raise NotImplementedError() @property - def list_annotations(self) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse] - ]]: + def list_annotations( + self, + ) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'DatasetServiceTransport', -) +__all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 39f0405cfa..ca597a1e69 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -46,21 +46,24 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -172,13 +175,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -211,7 +216,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -229,17 +234,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_dataset(self) -> Callable[ - [dataset_service.CreateDatasetRequest], - operations.Operation]: + def create_dataset( + self, + ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -254,18 +257,18 @@ def create_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_dataset' not in self._stubs: - self._stubs['create_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_dataset'] + return self._stubs["create_dataset"] @property - def get_dataset(self) -> Callable[ - [dataset_service.GetDatasetRequest], - dataset.Dataset]: + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -280,18 +283,18 @@ def get_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_dataset' not in self._stubs: - self._stubs['get_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs['get_dataset'] + return self._stubs["get_dataset"] @property - def update_dataset(self) -> Callable[ - [dataset_service.UpdateDatasetRequest], - gca_dataset.Dataset]: + def update_dataset( + self, + ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -306,18 +309,20 @@ def update_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_dataset' not in self._stubs: - self._stubs['update_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs['update_dataset'] + return self._stubs["update_dataset"] @property - def list_datasets(self) -> Callable[ - [dataset_service.ListDatasetsRequest], - dataset_service.ListDatasetsResponse]: + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse + ]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -332,18 +337,18 @@ def list_datasets(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_datasets' not in self._stubs: - self._stubs['list_datasets'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs['list_datasets'] + return self._stubs["list_datasets"] @property - def delete_dataset(self) -> Callable[ - [dataset_service.DeleteDatasetRequest], - operations.Operation]: + def delete_dataset( + self, + ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -358,18 +363,18 @@ def delete_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_dataset' not in self._stubs: - self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_dataset'] + return self._stubs["delete_dataset"] @property - def import_data(self) -> Callable[ - [dataset_service.ImportDataRequest], - operations.Operation]: + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -384,18 +389,18 @@ def import_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_data' not in self._stubs: - self._stubs['import_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_data'] + return self._stubs["import_data"] @property - def export_data(self) -> Callable[ - [dataset_service.ExportDataRequest], - operations.Operation]: + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -410,18 +415,20 @@ def export_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_data' not in self._stubs: - self._stubs['export_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_data'] + return self._stubs["export_data"] @property - def list_data_items(self) -> Callable[ - [dataset_service.ListDataItemsRequest], - dataset_service.ListDataItemsResponse]: + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse + ]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -436,18 +443,20 @@ def list_data_items(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_items' not in self._stubs: - self._stubs['list_data_items'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs['list_data_items'] + return self._stubs["list_data_items"] @property - def get_annotation_spec(self) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - annotation_spec.AnnotationSpec]: + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec + ]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -462,18 +471,21 @@ def get_annotation_spec(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_annotation_spec' not in self._stubs: - self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs['get_annotation_spec'] + return self._stubs["get_annotation_spec"] @property - def list_annotations(self) -> Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse]: + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse, + ]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -488,15 +500,13 @@ def list_annotations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_annotations' not in self._stubs: - self._stubs['list_annotations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs['list_annotations'] + return self._stubs["list_annotations"] -__all__ = ( - 'DatasetServiceGrpcTransport', -) +__all__ = ("DatasetServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index 6ed4e0785b..f51fe3bf1b 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import annotation_spec @@ -53,13 +53,15 @@ class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -88,22 +90,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -242,9 +246,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_dataset(self) -> Callable[ - [dataset_service.CreateDatasetRequest], - Awaitable[operations.Operation]]: + def create_dataset( + self, + ) -> Callable[ + [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -259,18 +265,18 @@ def create_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_dataset' not in self._stubs: - self._stubs['create_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_dataset'] + return self._stubs["create_dataset"] @property - def get_dataset(self) -> Callable[ - [dataset_service.GetDatasetRequest], - Awaitable[dataset.Dataset]]: + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -285,18 +291,20 @@ def get_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_dataset' not in self._stubs: - self._stubs['get_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs['get_dataset'] + return self._stubs["get_dataset"] @property - def update_dataset(self) -> Callable[ - [dataset_service.UpdateDatasetRequest], - Awaitable[gca_dataset.Dataset]]: + def update_dataset( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] + ]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -311,18 +319,21 @@ def update_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_dataset' not in self._stubs: - self._stubs['update_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs['update_dataset'] + return self._stubs["update_dataset"] @property - def list_datasets(self) -> Callable[ - [dataset_service.ListDatasetsRequest], - Awaitable[dataset_service.ListDatasetsResponse]]: + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse], + ]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -337,18 +348,20 @@ def list_datasets(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_datasets' not in self._stubs: - self._stubs['list_datasets'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs['list_datasets'] + return self._stubs["list_datasets"] @property - def delete_dataset(self) -> Callable[ - [dataset_service.DeleteDatasetRequest], - Awaitable[operations.Operation]]: + def delete_dataset( + self, + ) -> Callable[ + [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -363,18 +376,18 @@ def delete_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_dataset' not in self._stubs: - self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_dataset'] + return self._stubs["delete_dataset"] @property - def import_data(self) -> Callable[ - [dataset_service.ImportDataRequest], - Awaitable[operations.Operation]]: + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -389,18 +402,18 @@ def import_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_data' not in self._stubs: - self._stubs['import_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_data'] + return self._stubs["import_data"] @property - def export_data(self) -> Callable[ - [dataset_service.ExportDataRequest], - Awaitable[operations.Operation]]: + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -415,18 +428,21 @@ def export_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_data' not in self._stubs: - self._stubs['export_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_data'] + return self._stubs["export_data"] @property - def list_data_items(self) -> Callable[ - [dataset_service.ListDataItemsRequest], - Awaitable[dataset_service.ListDataItemsResponse]]: + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse], + ]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -441,18 +457,21 @@ def list_data_items(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_items' not in self._stubs: - self._stubs['list_data_items'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs['list_data_items'] + return self._stubs["list_data_items"] @property - def get_annotation_spec(self) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - Awaitable[annotation_spec.AnnotationSpec]]: + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + Awaitable[annotation_spec.AnnotationSpec], + ]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -467,18 +486,21 @@ def get_annotation_spec(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_annotation_spec' not in self._stubs: - self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs['get_annotation_spec'] + return self._stubs["get_annotation_spec"] @property - def list_annotations(self) -> Callable[ - [dataset_service.ListAnnotationsRequest], - Awaitable[dataset_service.ListAnnotationsResponse]]: + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + Awaitable[dataset_service.ListAnnotationsResponse], + ]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -493,15 +515,13 @@ def list_annotations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_annotations' not in self._stubs: - self._stubs['list_annotations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs['list_annotations'] + return self._stubs["list_annotations"] -__all__ = ( - 'DatasetServiceGrpcAsyncIOTransport', -) +__all__ = ("DatasetServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py index e4f3dcfbcf..035a5b2388 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import EndpointServiceAsyncClient __all__ = ( - 'EndpointServiceClient', - 'EndpointServiceAsyncClient', + "EndpointServiceClient", + "EndpointServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index daadc92c9e..05aa538225 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -58,20 +58,34 @@ class EndpointServiceAsyncClient: model_path = staticmethod(EndpointServiceClient.model_path) parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) - common_billing_account_path = staticmethod(EndpointServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(EndpointServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + EndpointServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + EndpointServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(EndpointServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + EndpointServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(EndpointServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(EndpointServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + EndpointServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + EndpointServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(EndpointServiceClient.common_project_path) - parse_common_project_path = staticmethod(EndpointServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + EndpointServiceClient.parse_common_project_path + ) common_location_path = staticmethod(EndpointServiceClient.common_location_path) - parse_common_location_path = staticmethod(EndpointServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + EndpointServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -114,14 +128,18 @@ def transport(self) -> EndpointServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient)) + get_transport_class = functools.partial( + type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, EndpointServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -160,18 +178,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an Endpoint. Args: @@ -211,8 +229,10 @@ async def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.CreateEndpointRequest(request) @@ -235,18 +255,11 @@ async def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -259,14 +272,15 @@ async def create_endpoint(self, # Done; return the response. return response - async def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + async def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -299,8 +313,10 @@ async def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.GetEndpointRequest(request) @@ -321,30 +337,24 @@ async def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsAsyncPager: + async def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: r"""Lists Endpoints in a Location. Args: @@ -380,8 +390,10 @@ async def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.ListEndpointsRequest(request) @@ -402,40 +414,31 @@ async def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListEndpointsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + async def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -475,8 +478,10 @@ async def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.UpdateEndpointRequest(request) @@ -499,30 +504,26 @@ async def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an Endpoint. Args: @@ -568,8 +569,10 @@ async def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.DeleteEndpointRequest(request) @@ -590,18 +593,11 @@ async def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -614,16 +610,19 @@ async def delete_endpoint(self, # Done; return the response. return response - async def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -692,8 +691,10 @@ async def deploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.DeployModelRequest(request) @@ -719,18 +720,11 @@ async def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -743,16 +737,19 @@ async def deploy_model(self, # Done; return the response. return response - async def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -812,8 +809,10 @@ async def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.UndeployModelRequest(request) @@ -839,18 +838,11 @@ async def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -864,21 +856,14 @@ async def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceAsyncClient', -) +__all__ = ("EndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 78822a9489..1fdf1e506e 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -56,13 +56,14 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry['grpc'] = EndpointServiceGrpcTransport - _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[EndpointServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry["grpc"] = EndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -113,7 +114,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -148,9 +149,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: EndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,88 +165,104 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -290,7 +306,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -300,7 +318,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -312,7 +332,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -324,8 +346,10 @@ def __init__(self, *, if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -344,15 +368,16 @@ def __init__(self, *, client_info=client_info, ) - def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates an Endpoint. Args: @@ -392,8 +417,10 @@ def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -417,18 +444,11 @@ def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -441,14 +461,15 @@ def create_endpoint(self, # Done; return the response. return response - def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -481,8 +502,10 @@ def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -504,30 +527,24 @@ def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -563,8 +580,10 @@ def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -586,40 +605,31 @@ def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -659,8 +669,10 @@ def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -684,30 +696,26 @@ def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes an Endpoint. Args: @@ -753,8 +761,10 @@ def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -776,18 +786,11 @@ def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -800,16 +803,19 @@ def delete_endpoint(self, # Done; return the response. return response - def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -878,8 +884,10 @@ def deploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -905,18 +913,11 @@ def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -929,16 +930,19 @@ def deploy_model(self, # Done; return the response. return response - def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -998,8 +1002,10 @@ def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -1025,18 +1031,11 @@ def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1050,21 +1049,14 @@ def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceClient', -) +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 4261cca3fb..db3172bcef 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -38,12 +47,15 @@ class ListEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., endpoint_service.ListEndpointsResponse], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: yield from page.endpoints def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListEndpointsAsyncPager: @@ -97,12 +109,15 @@ class ListEndpointsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -140,4 +155,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index eb2ef767fe..3d0695461d 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] -_transport_registry['grpc'] = EndpointServiceGrpcTransport -_transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = EndpointServiceGrpcTransport +_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport __all__ = ( - 'EndpointServiceTransport', - 'EndpointServiceGrpcTransport', - 'EndpointServiceGrpcAsyncIOTransport', + "EndpointServiceTransport", + "EndpointServiceGrpcTransport", + "EndpointServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index 85c53f94e3..9ff0668d04 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -80,8 +80,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -90,17 +90,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -109,41 +111,26 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, - default_timeout=5.0, - client_info=client_info, + self.create_endpoint, default_timeout=5.0, client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, - default_timeout=5.0, - client_info=client_info, + self.get_endpoint, default_timeout=5.0, client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, - default_timeout=5.0, - client_info=client_info, + self.list_endpoints, default_timeout=5.0, client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, - default_timeout=5.0, - client_info=client_info, + self.update_endpoint, default_timeout=5.0, client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, - default_timeout=5.0, - client_info=client_info, + self.delete_endpoint, default_timeout=5.0, client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, - default_timeout=5.0, - client_info=client_info, + self.deploy_model, default_timeout=5.0, client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, - default_timeout=5.0, - client_info=client_info, + self.undeploy_model, default_timeout=5.0, client_info=client_info, ), - } @property @@ -152,69 +139,70 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_endpoint(self) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_endpoint(self) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[ - endpoint.Endpoint, - typing.Awaitable[endpoint.Endpoint] - ]]: + def get_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def list_endpoints(self) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse] - ]]: + def list_endpoints( + self, + ) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse], + ], + ]: raise NotImplementedError() @property - def update_endpoint(self) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[ - gca_endpoint.Endpoint, - typing.Awaitable[gca_endpoint.Endpoint] - ]]: + def update_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def delete_endpoint(self) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def deploy_model(self) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def deploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def undeploy_model(self) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def undeploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'EndpointServiceTransport', -) +__all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 555432fec0..8943c2f3f0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -45,21 +45,24 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -171,13 +174,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -210,7 +215,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -228,17 +233,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_endpoint(self) -> Callable[ - [endpoint_service.CreateEndpointRequest], - operations.Operation]: + def create_endpoint( + self, + ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -253,18 +256,18 @@ def create_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_endpoint' not in self._stubs: - self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', + if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_endpoint'] + return self._stubs["create_endpoint"] @property - def get_endpoint(self) -> Callable[ - [endpoint_service.GetEndpointRequest], - endpoint.Endpoint]: + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -279,18 +282,20 @@ def get_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_endpoint' not in self._stubs: - self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs['get_endpoint'] + return self._stubs["get_endpoint"] @property - def list_endpoints(self) -> Callable[ - [endpoint_service.ListEndpointsRequest], - endpoint_service.ListEndpointsResponse]: + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse + ]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -305,18 +310,18 @@ def list_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_endpoints' not in self._stubs: - self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs['list_endpoints'] + return self._stubs["list_endpoints"] @property - def update_endpoint(self) -> Callable[ - [endpoint_service.UpdateEndpointRequest], - gca_endpoint.Endpoint]: + def update_endpoint( + self, + ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -331,18 +336,18 @@ def update_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_endpoint' not in self._stubs: - self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs['update_endpoint'] + return self._stubs["update_endpoint"] @property - def delete_endpoint(self) -> Callable[ - [endpoint_service.DeleteEndpointRequest], - operations.Operation]: + def delete_endpoint( + self, + ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -357,18 +362,18 @@ def delete_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_endpoint' not in self._stubs: - self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', + if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_endpoint'] + return self._stubs["delete_endpoint"] @property - def deploy_model(self) -> Callable[ - [endpoint_service.DeployModelRequest], - operations.Operation]: + def deploy_model( + self, + ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -384,18 +389,18 @@ def deploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'deploy_model' not in self._stubs: - self._stubs['deploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel', + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['deploy_model'] + return self._stubs["deploy_model"] @property - def undeploy_model(self) -> Callable[ - [endpoint_service.UndeployModelRequest], - operations.Operation]: + def undeploy_model( + self, + ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -412,15 +417,13 @@ def undeploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'undeploy_model' not in self._stubs: - self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel', + if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['undeploy_model'] + return self._stubs["undeploy_model"] -__all__ = ( - 'EndpointServiceGrpcTransport', -) +__all__ = ("EndpointServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index 1c5fe7e1f4..141168146d 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import endpoint @@ -52,13 +52,15 @@ class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -87,22 +89,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -241,9 +245,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_endpoint(self) -> Callable[ - [endpoint_service.CreateEndpointRequest], - Awaitable[operations.Operation]]: + def create_endpoint( + self, + ) -> Callable[ + [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -258,18 +264,18 @@ def create_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_endpoint' not in self._stubs: - self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', + if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_endpoint'] + return self._stubs["create_endpoint"] @property - def get_endpoint(self) -> Callable[ - [endpoint_service.GetEndpointRequest], - Awaitable[endpoint.Endpoint]]: + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -284,18 +290,21 @@ def get_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_endpoint' not in self._stubs: - self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs['get_endpoint'] + return self._stubs["get_endpoint"] @property - def list_endpoints(self) -> Callable[ - [endpoint_service.ListEndpointsRequest], - Awaitable[endpoint_service.ListEndpointsResponse]]: + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse], + ]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -310,18 +319,20 @@ def list_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_endpoints' not in self._stubs: - self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs['list_endpoints'] + return self._stubs["list_endpoints"] @property - def update_endpoint(self) -> Callable[ - [endpoint_service.UpdateEndpointRequest], - Awaitable[gca_endpoint.Endpoint]]: + def update_endpoint( + self, + ) -> Callable[ + [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] + ]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -336,18 +347,20 @@ def update_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_endpoint' not in self._stubs: - self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs['update_endpoint'] + return self._stubs["update_endpoint"] @property - def delete_endpoint(self) -> Callable[ - [endpoint_service.DeleteEndpointRequest], - Awaitable[operations.Operation]]: + def delete_endpoint( + self, + ) -> Callable[ + [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -362,18 +375,20 @@ def delete_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_endpoint' not in self._stubs: - self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', + if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_endpoint'] + return self._stubs["delete_endpoint"] @property - def deploy_model(self) -> Callable[ - [endpoint_service.DeployModelRequest], - Awaitable[operations.Operation]]: + def deploy_model( + self, + ) -> Callable[ + [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -389,18 +404,20 @@ def deploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'deploy_model' not in self._stubs: - self._stubs['deploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel', + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['deploy_model'] + return self._stubs["deploy_model"] @property - def undeploy_model(self) -> Callable[ - [endpoint_service.UndeployModelRequest], - Awaitable[operations.Operation]]: + def undeploy_model( + self, + ) -> Callable[ + [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -417,15 +434,13 @@ def undeploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'undeploy_model' not in self._stubs: - self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel', + if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['undeploy_model'] + return self._stubs["undeploy_model"] -__all__ = ( - 'EndpointServiceGrpcAsyncIOTransport', -) +__all__ = ("EndpointServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py index 037407b714..5f157047f5 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import JobServiceAsyncClient __all__ = ( - 'JobServiceClient', - 'JobServiceAsyncClient', + "JobServiceClient", + "JobServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 8b0e8331bb..11eb7a1c8b 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -21,34 +21,42 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study @@ -74,38 +82,58 @@ class JobServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) - parse_batch_prediction_job_path = staticmethod(JobServiceClient.parse_batch_prediction_job_path) + parse_batch_prediction_job_path = staticmethod( + JobServiceClient.parse_batch_prediction_job_path + ) custom_job_path = staticmethod(JobServiceClient.custom_job_path) parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) - parse_data_labeling_job_path = staticmethod(JobServiceClient.parse_data_labeling_job_path) + parse_data_labeling_job_path = staticmethod( + JobServiceClient.parse_data_labeling_job_path + ) dataset_path = staticmethod(JobServiceClient.dataset_path) parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) endpoint_path = staticmethod(JobServiceClient.endpoint_path) parse_endpoint_path = staticmethod(JobServiceClient.parse_endpoint_path) - hyperparameter_tuning_job_path = staticmethod(JobServiceClient.hyperparameter_tuning_job_path) - parse_hyperparameter_tuning_job_path = staticmethod(JobServiceClient.parse_hyperparameter_tuning_job_path) + hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.hyperparameter_tuning_job_path + ) + parse_hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.parse_hyperparameter_tuning_job_path + ) model_path = staticmethod(JobServiceClient.model_path) parse_model_path = staticmethod(JobServiceClient.parse_model_path) - model_deployment_monitoring_job_path = staticmethod(JobServiceClient.model_deployment_monitoring_job_path) - parse_model_deployment_monitoring_job_path = staticmethod(JobServiceClient.parse_model_deployment_monitoring_job_path) + model_deployment_monitoring_job_path = staticmethod( + JobServiceClient.model_deployment_monitoring_job_path + ) + parse_model_deployment_monitoring_job_path = staticmethod( + JobServiceClient.parse_model_deployment_monitoring_job_path + ) trial_path = staticmethod(JobServiceClient.trial_path) parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) - common_billing_account_path = staticmethod(JobServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(JobServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + JobServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + JobServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(JobServiceClient.common_folder_path) parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) common_organization_path = staticmethod(JobServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(JobServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + JobServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(JobServiceClient.common_project_path) parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) common_location_path = staticmethod(JobServiceClient.common_location_path) - parse_common_location_path = staticmethod(JobServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + JobServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -148,14 +176,18 @@ def transport(self) -> JobServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(JobServiceClient).get_transport_class, type(JobServiceClient)) + get_transport_class = functools.partial( + type(JobServiceClient).get_transport_class, type(JobServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, JobServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -194,18 +226,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + async def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -250,8 +282,10 @@ async def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateCustomJobRequest(request) @@ -274,30 +308,24 @@ async def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + async def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -335,8 +363,10 @@ async def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetCustomJobRequest(request) @@ -357,30 +387,24 @@ async def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsAsyncPager: + async def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: r"""Lists CustomJobs in a Location. Args: @@ -416,8 +440,10 @@ async def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListCustomJobsRequest(request) @@ -438,39 +464,30 @@ async def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListCustomJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a CustomJob. Args: @@ -516,8 +533,10 @@ async def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteCustomJobRequest(request) @@ -538,18 +557,11 @@ async def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -562,14 +574,15 @@ async def delete_custom_job(self, # Done; return the response. return response - async def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -607,8 +620,10 @@ async def cancel_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelCustomJobRequest(request) @@ -629,28 +644,24 @@ async def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -690,8 +701,10 @@ async def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateDataLabelingJobRequest(request) @@ -714,30 +727,24 @@ async def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + async def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -770,8 +777,10 @@ async def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetDataLabelingJobRequest(request) @@ -792,30 +801,24 @@ async def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsAsyncPager: + async def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -850,8 +853,10 @@ async def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListDataLabelingJobsRequest(request) @@ -872,39 +877,30 @@ async def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataLabelingJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a DataLabelingJob. Args: @@ -950,8 +946,10 @@ async def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteDataLabelingJobRequest(request) @@ -972,18 +970,11 @@ async def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -996,14 +987,15 @@ async def delete_data_labeling_job(self, # Done; return the response. return response - async def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1030,8 +1022,10 @@ async def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelDataLabelingJobRequest(request) @@ -1052,28 +1046,24 @@ async def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1115,8 +1105,10 @@ async def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateHyperparameterTuningJobRequest(request) @@ -1139,30 +1131,24 @@ async def create_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + async def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1197,8 +1183,10 @@ async def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetHyperparameterTuningJobRequest(request) @@ -1219,30 +1207,24 @@ async def get_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + async def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1278,8 +1260,10 @@ async def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListHyperparameterTuningJobsRequest(request) @@ -1300,39 +1284,30 @@ async def list_hyperparameter_tuning_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListHyperparameterTuningJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1378,8 +1353,10 @@ async def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteHyperparameterTuningJobRequest(request) @@ -1400,18 +1377,11 @@ async def delete_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1424,14 +1394,15 @@ async def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - async def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1471,8 +1442,10 @@ async def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelHyperparameterTuningJobRequest(request) @@ -1493,28 +1466,24 @@ async def cancel_hyperparameter_tuning_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1559,8 +1528,10 @@ async def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateBatchPredictionJobRequest(request) @@ -1583,30 +1554,24 @@ async def create_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + async def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1643,8 +1608,10 @@ async def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetBatchPredictionJobRequest(request) @@ -1665,30 +1632,24 @@ async def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsAsyncPager: + async def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsAsyncPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1724,8 +1685,10 @@ async def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListBatchPredictionJobsRequest(request) @@ -1746,39 +1709,30 @@ async def list_batch_prediction_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListBatchPredictionJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -1825,8 +1779,10 @@ async def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteBatchPredictionJobRequest(request) @@ -1847,18 +1803,11 @@ async def delete_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1871,14 +1820,15 @@ async def delete_batch_prediction_job(self, # Done; return the response. return response - async def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -1916,8 +1866,10 @@ async def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CancelBatchPredictionJobRequest(request) @@ -1938,28 +1890,24 @@ async def cancel_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def create_model_deployment_monitoring_job(self, - request: job_service.CreateModelDeploymentMonitoringJobRequest = None, - *, - parent: str = None, - model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_model_deployment_monitoring_job( + self, + request: job_service.CreateModelDeploymentMonitoringJobRequest = None, + *, + parent: str = None, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: r"""Creates a ModelDeploymentMonitoringJob. It will run periodically on a configured interval. @@ -2003,8 +1951,10 @@ async def create_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.CreateModelDeploymentMonitoringJobRequest(request) @@ -2027,31 +1977,25 @@ async def create_model_deployment_monitoring_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def search_model_deployment_monitoring_stats_anomalies(self, - request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, - *, - model_deployment_monitoring_job: str = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: + async def search_model_deployment_monitoring_stats_anomalies( + self, + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, + *, + model_deployment_monitoring_job: str = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: r"""Searches Model Monitoring Statistics generated within a given time window. @@ -2095,10 +2039,14 @@ async def search_model_deployment_monitoring_stats_anomalies(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) - request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2119,39 +2067,37 @@ async def search_model_deployment_monitoring_stats_anomalies(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job', request.model_deployment_monitoring_job), - )), + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "model_deployment_monitoring_job", + request.model_deployment_monitoring_job, + ), + ) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_model_deployment_monitoring_job(self, - request: job_service.GetModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + async def get_model_deployment_monitoring_job( + self, + request: job_service.GetModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: r"""Gets a ModelDeploymentMonitoringJob. Args: @@ -2187,8 +2133,10 @@ async def get_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.GetModelDeploymentMonitoringJobRequest(request) @@ -2209,30 +2157,24 @@ async def get_model_deployment_monitoring_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_model_deployment_monitoring_jobs(self, - request: job_service.ListModelDeploymentMonitoringJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelDeploymentMonitoringJobsAsyncPager: + async def list_model_deployment_monitoring_jobs( + self, + request: job_service.ListModelDeploymentMonitoringJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelDeploymentMonitoringJobsAsyncPager: r"""Lists ModelDeploymentMonitoringJobs in a Location. Args: @@ -2268,8 +2210,10 @@ async def list_model_deployment_monitoring_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ListModelDeploymentMonitoringJobsRequest(request) @@ -2290,40 +2234,31 @@ async def list_model_deployment_monitoring_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelDeploymentMonitoringJobsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_model_deployment_monitoring_job(self, - request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, - *, - model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_model_deployment_monitoring_job( + self, + request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, + *, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates a ModelDeploymentMonitoringJob. Args: @@ -2366,8 +2301,10 @@ async def update_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) @@ -2390,18 +2327,18 @@ async def update_model_deployment_monitoring_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job.name', request.model_deployment_monitoring_job.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "model_deployment_monitoring_job.name", + request.model_deployment_monitoring_job.name, + ), + ) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -2414,14 +2351,15 @@ async def update_model_deployment_monitoring_job(self, # Done; return the response. return response - async def delete_model_deployment_monitoring_job(self, - request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_model_deployment_monitoring_job( + self, + request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a ModelDeploymentMonitoringJob. Args: @@ -2467,8 +2405,10 @@ async def delete_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) @@ -2489,18 +2429,11 @@ async def delete_model_deployment_monitoring_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -2513,14 +2446,15 @@ async def delete_model_deployment_monitoring_job(self, # Done; return the response. return response - async def pause_model_deployment_monitoring_job(self, - request: job_service.PauseModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def pause_model_deployment_monitoring_job( + self, + request: job_service.PauseModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Pauses a ModelDeploymentMonitoringJob. If the job is running, the server makes a best effort to cancel the job. Will mark ``ModelDeploymentMonitoringJob.state`` @@ -2550,8 +2484,10 @@ async def pause_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.PauseModelDeploymentMonitoringJobRequest(request) @@ -2572,27 +2508,23 @@ async def pause_model_deployment_monitoring_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - async def resume_model_deployment_monitoring_job(self, - request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def resume_model_deployment_monitoring_job( + self, + request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Resumes a paused ModelDeploymentMonitoringJob. It will start to run from next scheduled time. A deleted ModelDeploymentMonitoringJob can't be resumed. @@ -2621,8 +2553,10 @@ async def resume_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) @@ -2643,35 +2577,23 @@ async def resume_model_deployment_monitoring_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceAsyncClient', -) +__all__ = ("JobServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index cb4d402b6a..aa08265c28 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -23,36 +23,44 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study @@ -76,13 +84,12 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry['grpc'] = JobServiceGrpcTransport - _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[JobServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -133,7 +140,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -168,9 +175,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: JobServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -185,165 +191,230 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: + def batch_prediction_job_path( + project: str, location: str, batch_prediction_job: str, + ) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, + location=location, + batch_prediction_job=batch_prediction_job, + ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str,location: str,custom_job: str,) -> str: + def custom_job_path(project: str, location: str, custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str,str]: + def parse_custom_job_path(path: str) -> Dict[str, str]: """Parse a custom_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str,str]: + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: + def hyperparameter_tuning_job_path( + project: str, location: str, hyperparameter_tuning_job: str, + ) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_deployment_monitoring_job_path(project: str,location: str,model_deployment_monitoring_job: str,) -> str: + def model_deployment_monitoring_job_path( + project: str, location: str, model_deployment_monitoring_job: str, + ) -> str: """Return a fully-qualified model_deployment_monitoring_job string.""" - return "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format(project=project, location=location, model_deployment_monitoring_job=model_deployment_monitoring_job, ) + return "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format( + project=project, + location=location, + model_deployment_monitoring_job=model_deployment_monitoring_job, + ) @staticmethod - def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str,str]: + def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str, str]: """Parse a model_deployment_monitoring_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/modelDeploymentMonitoringJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/modelDeploymentMonitoringJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str,location: str,study: str,trial: str,) -> str: + def trial_path(project: str, location: str, study: str, trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) @staticmethod - def parse_trial_path(path: str) -> Dict[str,str]: + def parse_trial_path(path: str) -> Dict[str, str]: """Parse a trial path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -387,7 +458,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -397,7 +470,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -409,7 +484,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -421,8 +498,10 @@ def __init__(self, *, if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -441,15 +520,16 @@ def __init__(self, *, client_info=client_info, ) - def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -494,8 +574,10 @@ def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -519,30 +601,24 @@ def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -580,8 +656,10 @@ def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -603,30 +681,24 @@ def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -662,8 +734,10 @@ def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -685,39 +759,30 @@ def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a CustomJob. Args: @@ -763,8 +828,10 @@ def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -786,18 +853,11 @@ def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -810,14 +870,15 @@ def delete_custom_job(self, # Done; return the response. return response - def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -855,8 +916,10 @@ def cancel_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -878,28 +941,24 @@ def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -939,8 +998,10 @@ def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -964,30 +1025,24 @@ def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -1020,8 +1075,10 @@ def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -1043,30 +1100,24 @@ def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1101,8 +1152,10 @@ def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1124,39 +1177,30 @@ def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1202,8 +1246,10 @@ def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1225,18 +1271,11 @@ def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1249,14 +1288,15 @@ def delete_data_labeling_job(self, # Done; return the response. return response - def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1283,8 +1323,10 @@ def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1306,28 +1348,24 @@ def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1369,8 +1407,10 @@ def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1389,35 +1429,31 @@ def create_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1452,8 +1488,10 @@ def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1470,35 +1508,31 @@ def get_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1534,8 +1568,10 @@ def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1552,44 +1588,37 @@ def list_hyperparameter_tuning_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1635,8 +1664,10 @@ def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1653,23 +1684,18 @@ def delete_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1682,14 +1708,15 @@ def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1729,8 +1756,10 @@ def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1747,33 +1776,31 @@ def cancel_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1818,8 +1845,10 @@ def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1838,35 +1867,31 @@ def create_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1903,8 +1928,10 @@ def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1926,30 +1953,24 @@ def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1985,8 +2006,10 @@ def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -2003,44 +2026,37 @@ def list_batch_prediction_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2087,8 +2103,10 @@ def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2105,23 +2123,18 @@ def delete_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2134,14 +2147,15 @@ def delete_batch_prediction_job(self, # Done; return the response. return response - def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -2179,8 +2193,10 @@ def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2197,33 +2213,31 @@ def cancel_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_model_deployment_monitoring_job(self, - request: job_service.CreateModelDeploymentMonitoringJobRequest = None, - *, - parent: str = None, - model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_model_deployment_monitoring_job( + self, + request: job_service.CreateModelDeploymentMonitoringJobRequest = None, + *, + parent: str = None, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: r"""Creates a ModelDeploymentMonitoringJob. It will run periodically on a configured interval. @@ -2267,14 +2281,18 @@ def create_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.CreateModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.CreateModelDeploymentMonitoringJobRequest + ): request = job_service.CreateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2283,40 +2301,38 @@ def create_model_deployment_monitoring_job(self, if parent is not None: request.parent = parent if model_deployment_monitoring_job is not None: - request.model_deployment_monitoring_job = model_deployment_monitoring_job + request.model_deployment_monitoring_job = ( + model_deployment_monitoring_job + ) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def search_model_deployment_monitoring_stats_anomalies(self, - request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, - *, - model_deployment_monitoring_job: str = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: + def search_model_deployment_monitoring_stats_anomalies( + self, + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, + *, + model_deployment_monitoring_job: str = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: r"""Searches Model Monitoring Statistics generated within a given time window. @@ -2360,64 +2376,72 @@ def search_model_deployment_monitoring_stats_anomalies(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): - request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + if not isinstance( + request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest + ): + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. if model_deployment_monitoring_job is not None: - request.model_deployment_monitoring_job = model_deployment_monitoring_job + request.model_deployment_monitoring_job = ( + model_deployment_monitoring_job + ) if deployed_model_id is not None: request.deployed_model_id = deployed_model_id # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.search_model_deployment_monitoring_stats_anomalies] + rpc = self._transport._wrapped_methods[ + self._transport.search_model_deployment_monitoring_stats_anomalies + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job', request.model_deployment_monitoring_job), - )), + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "model_deployment_monitoring_job", + request.model_deployment_monitoring_job, + ), + ) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_deployment_monitoring_job(self, - request: job_service.GetModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + def get_model_deployment_monitoring_job( + self, + request: job_service.GetModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: r"""Gets a ModelDeploymentMonitoringJob. Args: @@ -2453,8 +2477,10 @@ def get_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetModelDeploymentMonitoringJobRequest. @@ -2471,35 +2497,31 @@ def get_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_deployment_monitoring_jobs(self, - request: job_service.ListModelDeploymentMonitoringJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelDeploymentMonitoringJobsPager: + def list_model_deployment_monitoring_jobs( + self, + request: job_service.ListModelDeploymentMonitoringJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelDeploymentMonitoringJobsPager: r"""Lists ModelDeploymentMonitoringJobs in a Location. Args: @@ -2535,14 +2557,18 @@ def list_model_deployment_monitoring_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListModelDeploymentMonitoringJobsRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.ListModelDeploymentMonitoringJobsRequest): + if not isinstance( + request, job_service.ListModelDeploymentMonitoringJobsRequest + ): request = job_service.ListModelDeploymentMonitoringJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2553,45 +2579,38 @@ def list_model_deployment_monitoring_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_deployment_monitoring_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_deployment_monitoring_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelDeploymentMonitoringJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model_deployment_monitoring_job(self, - request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, - *, - model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_model_deployment_monitoring_job( + self, + request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, + *, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a ModelDeploymentMonitoringJob. Args: @@ -2634,43 +2653,51 @@ def update_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.UpdateModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.UpdateModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.UpdateModelDeploymentMonitoringJobRequest + ): request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. if model_deployment_monitoring_job is not None: - request.model_deployment_monitoring_job = model_deployment_monitoring_job + request.model_deployment_monitoring_job = ( + model_deployment_monitoring_job + ) if update_mask is not None: request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.update_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.update_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job.name', request.model_deployment_monitoring_job.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "model_deployment_monitoring_job.name", + request.model_deployment_monitoring_job.name, + ), + ) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2683,14 +2710,15 @@ def update_model_deployment_monitoring_job(self, # Done; return the response. return response - def delete_model_deployment_monitoring_job(self, - request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model_deployment_monitoring_job( + self, + request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a ModelDeploymentMonitoringJob. Args: @@ -2736,14 +2764,18 @@ def delete_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.DeleteModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.DeleteModelDeploymentMonitoringJobRequest + ): request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2754,23 +2786,18 @@ def delete_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2783,14 +2810,15 @@ def delete_model_deployment_monitoring_job(self, # Done; return the response. return response - def pause_model_deployment_monitoring_job(self, - request: job_service.PauseModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def pause_model_deployment_monitoring_job( + self, + request: job_service.PauseModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Pauses a ModelDeploymentMonitoringJob. If the job is running, the server makes a best effort to cancel the job. Will mark ``ModelDeploymentMonitoringJob.state`` @@ -2820,14 +2848,18 @@ def pause_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.PauseModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.PauseModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.PauseModelDeploymentMonitoringJobRequest + ): request = job_service.PauseModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2838,32 +2870,30 @@ def pause_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.pause_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.pause_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def resume_model_deployment_monitoring_job(self, - request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def resume_model_deployment_monitoring_job( + self, + request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Resumes a paused ModelDeploymentMonitoringJob. It will start to run from next scheduled time. A deleted ModelDeploymentMonitoringJob can't be resumed. @@ -2892,14 +2922,18 @@ def resume_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ResumeModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.ResumeModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.ResumeModelDeploymentMonitoringJobRequest + ): request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2910,40 +2944,30 @@ def resume_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.resume_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.resume_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceClient', -) +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 85cb433f67..2ccecac0eb 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job @@ -23,7 +32,9 @@ from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) class ListCustomJobsPager: @@ -43,12 +54,15 @@ class ListCustomJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListCustomJobsResponse], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -82,7 +96,7 @@ def __iter__(self) -> Iterable[custom_job.CustomJob]: yield from page.custom_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListCustomJobsAsyncPager: @@ -102,12 +116,15 @@ class ListCustomJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -145,7 +162,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataLabelingJobsPager: @@ -165,12 +182,15 @@ class ListDataLabelingJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListDataLabelingJobsResponse], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -204,7 +224,7 @@ def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: yield from page.data_labeling_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataLabelingJobsAsyncPager: @@ -224,12 +244,15 @@ class ListDataLabelingJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -267,7 +290,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsPager: @@ -287,12 +310,15 @@ class ListHyperparameterTuningJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -326,7 +352,7 @@ def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob yield from page.hyperparameter_tuning_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsAsyncPager: @@ -346,12 +372,17 @@ class ListHyperparameterTuningJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListHyperparameterTuningJobsResponse]], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -373,14 +404,18 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + async def pages( + self, + ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__(self) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + def __aiter__( + self, + ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: async def async_generator(): async for page in self.pages: for response in page.hyperparameter_tuning_jobs: @@ -389,7 +424,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListBatchPredictionJobsPager: @@ -409,12 +444,15 @@ class ListBatchPredictionJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListBatchPredictionJobsResponse], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -448,7 +486,7 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: yield from page.batch_prediction_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListBatchPredictionJobsAsyncPager: @@ -468,12 +506,15 @@ class ListBatchPredictionJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -511,7 +552,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchModelDeploymentMonitoringStatsAnomaliesPager: @@ -531,12 +572,17 @@ class SearchModelDeploymentMonitoringStatsAnomaliesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse], - request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, - response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse + ], + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, + response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -550,7 +596,9 @@ def __init__(self, sent along with the request as metadata. """ self._method = method - self._request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + self._request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) self._response = response self._metadata = metadata @@ -558,19 +606,23 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - def pages(self) -> Iterable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + def pages( + self, + ) -> Iterable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = self._method(self._request, metadata=self._metadata) yield self._response - def __iter__(self) -> Iterable[gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies]: + def __iter__( + self, + ) -> Iterable[gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies]: for page in self.pages: yield from page.monitoring_stats def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: @@ -590,12 +642,20 @@ class SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]], - request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, - response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., + Awaitable[ + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse + ], + ], + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, + response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -609,7 +669,9 @@ def __init__(self, sent along with the request as metadata. """ self._method = method - self._request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + self._request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) self._response = response self._metadata = metadata @@ -617,14 +679,22 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + async def pages( + self, + ) -> AsyncIterable[ + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse + ]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__(self) -> AsyncIterable[gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies]: + def __aiter__( + self, + ) -> AsyncIterable[ + gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies + ]: async def async_generator(): async for page in self.pages: for response in page.monitoring_stats: @@ -633,7 +703,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelDeploymentMonitoringJobsPager: @@ -653,12 +723,15 @@ class ListModelDeploymentMonitoringJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListModelDeploymentMonitoringJobsResponse], - request: job_service.ListModelDeploymentMonitoringJobsRequest, - response: job_service.ListModelDeploymentMonitoringJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListModelDeploymentMonitoringJobsResponse], + request: job_service.ListModelDeploymentMonitoringJobsRequest, + response: job_service.ListModelDeploymentMonitoringJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -687,12 +760,14 @@ def pages(self) -> Iterable[job_service.ListModelDeploymentMonitoringJobsRespons self._response = self._method(self._request, metadata=self._metadata) yield self._response - def __iter__(self) -> Iterable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + def __iter__( + self, + ) -> Iterable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: for page in self.pages: yield from page.model_deployment_monitoring_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelDeploymentMonitoringJobsAsyncPager: @@ -712,12 +787,17 @@ class ListModelDeploymentMonitoringJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse]], - request: job_service.ListModelDeploymentMonitoringJobsRequest, - response: job_service.ListModelDeploymentMonitoringJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse] + ], + request: job_service.ListModelDeploymentMonitoringJobsRequest, + response: job_service.ListModelDeploymentMonitoringJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -739,14 +819,18 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[job_service.ListModelDeploymentMonitoringJobsResponse]: + async def pages( + self, + ) -> AsyncIterable[job_service.ListModelDeploymentMonitoringJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__(self) -> AsyncIterable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + def __aiter__( + self, + ) -> AsyncIterable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: async def async_generator(): async for page in self.pages: for response in page.model_deployment_monitoring_jobs: @@ -755,4 +839,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index 8b5de46a7e..349bfbcdea 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] -_transport_registry['grpc'] = JobServiceGrpcTransport -_transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = JobServiceGrpcTransport +_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport __all__ = ( - 'JobServiceTransport', - 'JobServiceGrpcTransport', - 'JobServiceGrpcAsyncIOTransport', + "JobServiceTransport", + "JobServiceGrpcTransport", + "JobServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 8ec1ad88c2..df82cc8821 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -21,22 +21,30 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -44,29 +52,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -89,8 +97,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -99,17 +107,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -118,29 +128,19 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, - default_timeout=5.0, - client_info=client_info, + self.create_custom_job, default_timeout=5.0, client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, - default_timeout=5.0, - client_info=client_info, + self.get_custom_job, default_timeout=5.0, client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, - default_timeout=5.0, - client_info=client_info, + self.list_custom_jobs, default_timeout=5.0, client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, - default_timeout=5.0, - client_info=client_info, + self.delete_custom_job, default_timeout=5.0, client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, - default_timeout=5.0, - client_info=client_info, + self.cancel_custom_job, default_timeout=5.0, client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, @@ -257,7 +257,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -266,258 +265,306 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_custom_job(self) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, - typing.Awaitable[gca_custom_job.CustomJob] - ]]: + def create_custom_job( + self, + ) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] + ], + ]: raise NotImplementedError() @property - def get_custom_job(self) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[ - custom_job.CustomJob, - typing.Awaitable[custom_job.CustomJob] - ]]: + def get_custom_job( + self, + ) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], + ]: raise NotImplementedError() @property - def list_custom_jobs(self) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse] - ]]: + def list_custom_jobs( + self, + ) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_custom_job(self) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_custom_job( + self, + ) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_custom_job(self) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_custom_job( + self, + ) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_data_labeling_job(self) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob] - ]]: + def create_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def get_data_labeling_job(self) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob] - ]]: + def get_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def list_data_labeling_jobs(self) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse] - ]]: + def list_data_labeling_jobs( + self, + ) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_data_labeling_job(self) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_data_labeling_job(self) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def create_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def get_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def get_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_batch_prediction_job(self) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] - ]]: + def create_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def get_batch_prediction_job(self) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - typing.Union[ - batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob] - ]]: + def get_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def list_batch_prediction_jobs(self) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse] - ]]: + def list_batch_prediction_jobs( + self, + ) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_batch_prediction_job(self) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_batch_prediction_job(self) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_model_deployment_monitoring_job(self) -> typing.Callable[ - [job_service.CreateModelDeploymentMonitoringJobRequest], - typing.Union[ - gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, - typing.Awaitable[gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob] - ]]: + def create_model_deployment_monitoring_job( + self, + ) -> typing.Callable[ + [job_service.CreateModelDeploymentMonitoringJobRequest], + typing.Union[ + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + typing.Awaitable[ + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob + ], + ], + ]: raise NotImplementedError() @property - def search_model_deployment_monitoring_stats_anomalies(self) -> typing.Callable[ - [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], - typing.Union[ - job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, - typing.Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse] - ]]: + def search_model_deployment_monitoring_stats_anomalies( + self, + ) -> typing.Callable[ + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + typing.Union[ + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + typing.Awaitable[ + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse + ], + ], + ]: raise NotImplementedError() @property - def get_model_deployment_monitoring_job(self) -> typing.Callable[ - [job_service.GetModelDeploymentMonitoringJobRequest], - typing.Union[ - model_deployment_monitoring_job.ModelDeploymentMonitoringJob, - typing.Awaitable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob] - ]]: + def get_model_deployment_monitoring_job( + self, + ) -> typing.Callable[ + [job_service.GetModelDeploymentMonitoringJobRequest], + typing.Union[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + typing.Awaitable[ + model_deployment_monitoring_job.ModelDeploymentMonitoringJob + ], + ], + ]: raise NotImplementedError() @property - def list_model_deployment_monitoring_jobs(self) -> typing.Callable[ - [job_service.ListModelDeploymentMonitoringJobsRequest], - typing.Union[ - job_service.ListModelDeploymentMonitoringJobsResponse, - typing.Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse] - ]]: + def list_model_deployment_monitoring_jobs( + self, + ) -> typing.Callable[ + [job_service.ListModelDeploymentMonitoringJobsRequest], + typing.Union[ + job_service.ListModelDeploymentMonitoringJobsResponse, + typing.Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse], + ], + ]: raise NotImplementedError() @property - def update_model_deployment_monitoring_job(self) -> typing.Callable[ - [job_service.UpdateModelDeploymentMonitoringJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_model_deployment_monitoring_job( + self, + ) -> typing.Callable[ + [job_service.UpdateModelDeploymentMonitoringJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def delete_model_deployment_monitoring_job(self) -> typing.Callable[ - [job_service.DeleteModelDeploymentMonitoringJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_model_deployment_monitoring_job( + self, + ) -> typing.Callable[ + [job_service.DeleteModelDeploymentMonitoringJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def pause_model_deployment_monitoring_job(self) -> typing.Callable[ - [job_service.PauseModelDeploymentMonitoringJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def pause_model_deployment_monitoring_job( + self, + ) -> typing.Callable[ + [job_service.PauseModelDeploymentMonitoringJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def resume_model_deployment_monitoring_job(self) -> typing.Callable[ - [job_service.ResumeModelDeploymentMonitoringJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def resume_model_deployment_monitoring_job( + self, + ) -> typing.Callable[ + [job_service.ResumeModelDeploymentMonitoringJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'JobServiceTransport', -) +__all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 61b67d0f98..19b09b4a2f 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -18,26 +18,34 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -56,21 +64,24 @@ class JobServiceGrpcTransport(JobServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -182,13 +193,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -221,7 +234,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -239,17 +252,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_custom_job(self) -> Callable[ - [job_service.CreateCustomJobRequest], - gca_custom_job.CustomJob]: + def create_custom_job( + self, + ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -265,18 +276,18 @@ def create_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_custom_job' not in self._stubs: - self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs['create_custom_job'] + return self._stubs["create_custom_job"] @property - def get_custom_job(self) -> Callable[ - [job_service.GetCustomJobRequest], - custom_job.CustomJob]: + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -291,18 +302,20 @@ def get_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_custom_job' not in self._stubs: - self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs['get_custom_job'] + return self._stubs["get_custom_job"] @property - def list_custom_jobs(self) -> Callable[ - [job_service.ListCustomJobsRequest], - job_service.ListCustomJobsResponse]: + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse + ]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -317,18 +330,18 @@ def list_custom_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_custom_jobs' not in self._stubs: - self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs['list_custom_jobs'] + return self._stubs["list_custom_jobs"] @property - def delete_custom_job(self) -> Callable[ - [job_service.DeleteCustomJobRequest], - operations.Operation]: + def delete_custom_job( + self, + ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -343,18 +356,18 @@ def delete_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_custom_job' not in self._stubs: - self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_custom_job'] + return self._stubs["delete_custom_job"] @property - def cancel_custom_job(self) -> Callable[ - [job_service.CancelCustomJobRequest], - empty.Empty]: + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -381,18 +394,21 @@ def cancel_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_custom_job' not in self._stubs: - self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_custom_job'] + return self._stubs["cancel_custom_job"] @property - def create_data_labeling_job(self) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob]: + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob, + ]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -407,18 +423,20 @@ def create_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_data_labeling_job' not in self._stubs: - self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['create_data_labeling_job'] + return self._stubs["create_data_labeling_job"] @property - def get_data_labeling_job(self) -> Callable[ - [job_service.GetDataLabelingJobRequest], - data_labeling_job.DataLabelingJob]: + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob + ]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -433,18 +451,21 @@ def get_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_data_labeling_job' not in self._stubs: - self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['get_data_labeling_job'] + return self._stubs["get_data_labeling_job"] @property - def list_data_labeling_jobs(self) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse]: + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse, + ]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -459,18 +480,18 @@ def list_data_labeling_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_labeling_jobs' not in self._stubs: - self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs['list_data_labeling_jobs'] + return self._stubs["list_data_labeling_jobs"] @property - def delete_data_labeling_job(self) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], - operations.Operation]: + def delete_data_labeling_job( + self, + ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -485,18 +506,18 @@ def delete_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_data_labeling_job' not in self._stubs: - self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_data_labeling_job'] + return self._stubs["delete_data_labeling_job"] @property - def cancel_data_labeling_job(self) -> Callable[ - [job_service.CancelDataLabelingJobRequest], - empty.Empty]: + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -512,18 +533,21 @@ def cancel_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_data_labeling_job' not in self._stubs: - self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_data_labeling_job'] + return self._stubs["cancel_data_labeling_job"] @property - def create_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob]: + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + ]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -539,18 +563,23 @@ def create_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_hyperparameter_tuning_job' not in self._stubs: - self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['create_hyperparameter_tuning_job'] + return self._stubs["create_hyperparameter_tuning_job"] @property - def get_hyperparameter_tuning_job(self) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob]: + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob, + ]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -565,18 +594,23 @@ def get_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_hyperparameter_tuning_job' not in self._stubs: - self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['get_hyperparameter_tuning_job'] + return self._stubs["get_hyperparameter_tuning_job"] @property - def list_hyperparameter_tuning_jobs(self) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse]: + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse, + ]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -592,18 +626,22 @@ def list_hyperparameter_tuning_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_hyperparameter_tuning_jobs' not in self._stubs: - self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs['list_hyperparameter_tuning_jobs'] + return self._stubs["list_hyperparameter_tuning_jobs"] @property - def delete_hyperparameter_tuning_job(self) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - operations.Operation]: + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation + ]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -619,18 +657,20 @@ def delete_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_hyperparameter_tuning_job' not in self._stubs: - self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_hyperparameter_tuning_job'] + return self._stubs["delete_hyperparameter_tuning_job"] @property - def cancel_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - empty.Empty]: + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -659,18 +699,23 @@ def cancel_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_hyperparameter_tuning_job' not in self._stubs: - self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_hyperparameter_tuning_job'] + return self._stubs["cancel_hyperparameter_tuning_job"] @property - def create_batch_prediction_job(self) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob]: + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + gca_batch_prediction_job.BatchPredictionJob, + ]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -686,18 +731,21 @@ def create_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_batch_prediction_job' not in self._stubs: - self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['create_batch_prediction_job'] + return self._stubs["create_batch_prediction_job"] @property - def get_batch_prediction_job(self) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob]: + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob, + ]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -712,18 +760,21 @@ def get_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_batch_prediction_job' not in self._stubs: - self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['get_batch_prediction_job'] + return self._stubs["get_batch_prediction_job"] @property - def list_batch_prediction_jobs(self) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse]: + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse, + ]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -738,18 +789,18 @@ def list_batch_prediction_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_batch_prediction_jobs' not in self._stubs: - self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs['list_batch_prediction_jobs'] + return self._stubs["list_batch_prediction_jobs"] @property - def delete_batch_prediction_job(self) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], - operations.Operation]: + def delete_batch_prediction_job( + self, + ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -765,18 +816,18 @@ def delete_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_batch_prediction_job' not in self._stubs: - self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_batch_prediction_job'] + return self._stubs["delete_batch_prediction_job"] @property - def cancel_batch_prediction_job(self) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], - empty.Empty]: + def cancel_batch_prediction_job( + self, + ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -802,18 +853,21 @@ def cancel_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_batch_prediction_job' not in self._stubs: - self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', + if "cancel_batch_prediction_job" not in self._stubs: + self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_batch_prediction_job'] + return self._stubs["cancel_batch_prediction_job"] @property - def create_model_deployment_monitoring_job(self) -> Callable[ - [job_service.CreateModelDeploymentMonitoringJobRequest], - gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + def create_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.CreateModelDeploymentMonitoringJobRequest], + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + ]: r"""Return a callable for the create model deployment monitoring job method over gRPC. @@ -830,18 +884,23 @@ def create_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_model_deployment_monitoring_job' not in self._stubs: - self._stubs['create_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateModelDeploymentMonitoringJob', + if "create_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "create_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateModelDeploymentMonitoringJob", request_serializer=job_service.CreateModelDeploymentMonitoringJobRequest.serialize, response_deserializer=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, ) - return self._stubs['create_model_deployment_monitoring_job'] + return self._stubs["create_model_deployment_monitoring_job"] @property - def search_model_deployment_monitoring_stats_anomalies(self) -> Callable[ - [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], - job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]: + def search_model_deployment_monitoring_stats_anomalies( + self, + ) -> Callable[ + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, + ]: r"""Return a callable for the search model deployment monitoring stats anomalies method over gRPC. @@ -858,18 +917,23 @@ def search_model_deployment_monitoring_stats_anomalies(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_model_deployment_monitoring_stats_anomalies' not in self._stubs: - self._stubs['search_model_deployment_monitoring_stats_anomalies'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/SearchModelDeploymentMonitoringStatsAnomalies', + if "search_model_deployment_monitoring_stats_anomalies" not in self._stubs: + self._stubs[ + "search_model_deployment_monitoring_stats_anomalies" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/SearchModelDeploymentMonitoringStatsAnomalies", request_serializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest.serialize, response_deserializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse.deserialize, ) - return self._stubs['search_model_deployment_monitoring_stats_anomalies'] + return self._stubs["search_model_deployment_monitoring_stats_anomalies"] @property - def get_model_deployment_monitoring_job(self) -> Callable[ - [job_service.GetModelDeploymentMonitoringJobRequest], - model_deployment_monitoring_job.ModelDeploymentMonitoringJob]: + def get_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.GetModelDeploymentMonitoringJobRequest], + model_deployment_monitoring_job.ModelDeploymentMonitoringJob, + ]: r"""Return a callable for the get model deployment monitoring job method over gRPC. @@ -885,18 +949,23 @@ def get_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_deployment_monitoring_job' not in self._stubs: - self._stubs['get_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetModelDeploymentMonitoringJob', + if "get_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "get_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetModelDeploymentMonitoringJob", request_serializer=job_service.GetModelDeploymentMonitoringJobRequest.serialize, response_deserializer=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, ) - return self._stubs['get_model_deployment_monitoring_job'] + return self._stubs["get_model_deployment_monitoring_job"] @property - def list_model_deployment_monitoring_jobs(self) -> Callable[ - [job_service.ListModelDeploymentMonitoringJobsRequest], - job_service.ListModelDeploymentMonitoringJobsResponse]: + def list_model_deployment_monitoring_jobs( + self, + ) -> Callable[ + [job_service.ListModelDeploymentMonitoringJobsRequest], + job_service.ListModelDeploymentMonitoringJobsResponse, + ]: r"""Return a callable for the list model deployment monitoring jobs method over gRPC. @@ -912,18 +981,22 @@ def list_model_deployment_monitoring_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_deployment_monitoring_jobs' not in self._stubs: - self._stubs['list_model_deployment_monitoring_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListModelDeploymentMonitoringJobs', + if "list_model_deployment_monitoring_jobs" not in self._stubs: + self._stubs[ + "list_model_deployment_monitoring_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListModelDeploymentMonitoringJobs", request_serializer=job_service.ListModelDeploymentMonitoringJobsRequest.serialize, response_deserializer=job_service.ListModelDeploymentMonitoringJobsResponse.deserialize, ) - return self._stubs['list_model_deployment_monitoring_jobs'] + return self._stubs["list_model_deployment_monitoring_jobs"] @property - def update_model_deployment_monitoring_job(self) -> Callable[ - [job_service.UpdateModelDeploymentMonitoringJobRequest], - operations.Operation]: + def update_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.UpdateModelDeploymentMonitoringJobRequest], operations.Operation + ]: r"""Return a callable for the update model deployment monitoring job method over gRPC. @@ -939,18 +1012,22 @@ def update_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model_deployment_monitoring_job' not in self._stubs: - self._stubs['update_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/UpdateModelDeploymentMonitoringJob', + if "update_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "update_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/UpdateModelDeploymentMonitoringJob", request_serializer=job_service.UpdateModelDeploymentMonitoringJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_model_deployment_monitoring_job'] + return self._stubs["update_model_deployment_monitoring_job"] @property - def delete_model_deployment_monitoring_job(self) -> Callable[ - [job_service.DeleteModelDeploymentMonitoringJobRequest], - operations.Operation]: + def delete_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.DeleteModelDeploymentMonitoringJobRequest], operations.Operation + ]: r"""Return a callable for the delete model deployment monitoring job method over gRPC. @@ -966,18 +1043,20 @@ def delete_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model_deployment_monitoring_job' not in self._stubs: - self._stubs['delete_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteModelDeploymentMonitoringJob', + if "delete_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "delete_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteModelDeploymentMonitoringJob", request_serializer=job_service.DeleteModelDeploymentMonitoringJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model_deployment_monitoring_job'] + return self._stubs["delete_model_deployment_monitoring_job"] @property - def pause_model_deployment_monitoring_job(self) -> Callable[ - [job_service.PauseModelDeploymentMonitoringJobRequest], - empty.Empty]: + def pause_model_deployment_monitoring_job( + self, + ) -> Callable[[job_service.PauseModelDeploymentMonitoringJobRequest], empty.Empty]: r"""Return a callable for the pause model deployment monitoring job method over gRPC. @@ -996,18 +1075,20 @@ def pause_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'pause_model_deployment_monitoring_job' not in self._stubs: - self._stubs['pause_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/PauseModelDeploymentMonitoringJob', + if "pause_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "pause_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/PauseModelDeploymentMonitoringJob", request_serializer=job_service.PauseModelDeploymentMonitoringJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['pause_model_deployment_monitoring_job'] + return self._stubs["pause_model_deployment_monitoring_job"] @property - def resume_model_deployment_monitoring_job(self) -> Callable[ - [job_service.ResumeModelDeploymentMonitoringJobRequest], - empty.Empty]: + def resume_model_deployment_monitoring_job( + self, + ) -> Callable[[job_service.ResumeModelDeploymentMonitoringJobRequest], empty.Empty]: r"""Return a callable for the resume model deployment monitoring job method over gRPC. @@ -1025,15 +1106,15 @@ def resume_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'resume_model_deployment_monitoring_job' not in self._stubs: - self._stubs['resume_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ResumeModelDeploymentMonitoringJob', + if "resume_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "resume_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ResumeModelDeploymentMonitoringJob", request_serializer=job_service.ResumeModelDeploymentMonitoringJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['resume_model_deployment_monitoring_job'] + return self._stubs["resume_model_deployment_monitoring_job"] -__all__ = ( - 'JobServiceGrpcTransport', -) +__all__ = ("JobServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index 3cd0904008..417746df35 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -18,27 +18,35 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -63,13 +71,15 @@ class JobServiceGrpcAsyncIOTransport(JobServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -98,22 +108,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -252,9 +264,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_custom_job(self) -> Callable[ - [job_service.CreateCustomJobRequest], - Awaitable[gca_custom_job.CustomJob]]: + def create_custom_job( + self, + ) -> Callable[ + [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] + ]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -270,18 +284,18 @@ def create_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_custom_job' not in self._stubs: - self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs['create_custom_job'] + return self._stubs["create_custom_job"] @property - def get_custom_job(self) -> Callable[ - [job_service.GetCustomJobRequest], - Awaitable[custom_job.CustomJob]]: + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -296,18 +310,21 @@ def get_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_custom_job' not in self._stubs: - self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs['get_custom_job'] + return self._stubs["get_custom_job"] @property - def list_custom_jobs(self) -> Callable[ - [job_service.ListCustomJobsRequest], - Awaitable[job_service.ListCustomJobsResponse]]: + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse], + ]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -322,18 +339,20 @@ def list_custom_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_custom_jobs' not in self._stubs: - self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs['list_custom_jobs'] + return self._stubs["list_custom_jobs"] @property - def delete_custom_job(self) -> Callable[ - [job_service.DeleteCustomJobRequest], - Awaitable[operations.Operation]]: + def delete_custom_job( + self, + ) -> Callable[ + [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -348,18 +367,18 @@ def delete_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_custom_job' not in self._stubs: - self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_custom_job'] + return self._stubs["delete_custom_job"] @property - def cancel_custom_job(self) -> Callable[ - [job_service.CancelCustomJobRequest], - Awaitable[empty.Empty]]: + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -386,18 +405,21 @@ def cancel_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_custom_job' not in self._stubs: - self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_custom_job'] + return self._stubs["cancel_custom_job"] @property - def create_data_labeling_job(self) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - Awaitable[gca_data_labeling_job.DataLabelingJob]]: + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob], + ]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -412,18 +434,21 @@ def create_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_data_labeling_job' not in self._stubs: - self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['create_data_labeling_job'] + return self._stubs["create_data_labeling_job"] @property - def get_data_labeling_job(self) -> Callable[ - [job_service.GetDataLabelingJobRequest], - Awaitable[data_labeling_job.DataLabelingJob]]: + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob], + ]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -438,18 +463,21 @@ def get_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_data_labeling_job' not in self._stubs: - self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['get_data_labeling_job'] + return self._stubs["get_data_labeling_job"] @property - def list_data_labeling_jobs(self) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - Awaitable[job_service.ListDataLabelingJobsResponse]]: + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse], + ]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -464,18 +492,20 @@ def list_data_labeling_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_labeling_jobs' not in self._stubs: - self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs['list_data_labeling_jobs'] + return self._stubs["list_data_labeling_jobs"] @property - def delete_data_labeling_job(self) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], - Awaitable[operations.Operation]]: + def delete_data_labeling_job( + self, + ) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -490,18 +520,18 @@ def delete_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_data_labeling_job' not in self._stubs: - self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_data_labeling_job'] + return self._stubs["delete_data_labeling_job"] @property - def cancel_data_labeling_job(self) -> Callable[ - [job_service.CancelDataLabelingJobRequest], - Awaitable[empty.Empty]]: + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -517,18 +547,21 @@ def cancel_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_data_labeling_job' not in self._stubs: - self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_data_labeling_job'] + return self._stubs["cancel_data_labeling_job"] @property - def create_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob]]: + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -544,18 +577,23 @@ def create_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_hyperparameter_tuning_job' not in self._stubs: - self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['create_hyperparameter_tuning_job'] + return self._stubs["create_hyperparameter_tuning_job"] @property - def get_hyperparameter_tuning_job(self) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob]]: + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -570,18 +608,23 @@ def get_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_hyperparameter_tuning_job' not in self._stubs: - self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['get_hyperparameter_tuning_job'] + return self._stubs["get_hyperparameter_tuning_job"] @property - def list_hyperparameter_tuning_jobs(self) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - Awaitable[job_service.ListHyperparameterTuningJobsResponse]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -597,18 +640,23 @@ def list_hyperparameter_tuning_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_hyperparameter_tuning_jobs' not in self._stubs: - self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs['list_hyperparameter_tuning_jobs'] + return self._stubs["list_hyperparameter_tuning_jobs"] @property - def delete_hyperparameter_tuning_job(self) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - Awaitable[operations.Operation]]: + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -624,18 +672,22 @@ def delete_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_hyperparameter_tuning_job' not in self._stubs: - self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_hyperparameter_tuning_job'] + return self._stubs["delete_hyperparameter_tuning_job"] @property - def cancel_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - Awaitable[empty.Empty]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -664,18 +716,23 @@ def cancel_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_hyperparameter_tuning_job' not in self._stubs: - self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_hyperparameter_tuning_job'] + return self._stubs["cancel_hyperparameter_tuning_job"] @property - def create_batch_prediction_job(self) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - Awaitable[gca_batch_prediction_job.BatchPredictionJob]]: + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -691,18 +748,21 @@ def create_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_batch_prediction_job' not in self._stubs: - self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['create_batch_prediction_job'] + return self._stubs["create_batch_prediction_job"] @property - def get_batch_prediction_job(self) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - Awaitable[batch_prediction_job.BatchPredictionJob]]: + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob], + ]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -717,18 +777,21 @@ def get_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_batch_prediction_job' not in self._stubs: - self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['get_batch_prediction_job'] + return self._stubs["get_batch_prediction_job"] @property - def list_batch_prediction_jobs(self) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - Awaitable[job_service.ListBatchPredictionJobsResponse]]: + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse], + ]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -743,18 +806,20 @@ def list_batch_prediction_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_batch_prediction_jobs' not in self._stubs: - self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs['list_batch_prediction_jobs'] + return self._stubs["list_batch_prediction_jobs"] @property - def delete_batch_prediction_job(self) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], - Awaitable[operations.Operation]]: + def delete_batch_prediction_job( + self, + ) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -770,18 +835,20 @@ def delete_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_batch_prediction_job' not in self._stubs: - self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_batch_prediction_job'] + return self._stubs["delete_batch_prediction_job"] @property - def cancel_batch_prediction_job(self) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], - Awaitable[empty.Empty]]: + def cancel_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -807,18 +874,21 @@ def cancel_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_batch_prediction_job' not in self._stubs: - self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', + if "cancel_batch_prediction_job" not in self._stubs: + self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_batch_prediction_job'] + return self._stubs["cancel_batch_prediction_job"] @property - def create_model_deployment_monitoring_job(self) -> Callable[ - [job_service.CreateModelDeploymentMonitoringJobRequest], - Awaitable[gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob]]: + def create_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.CreateModelDeploymentMonitoringJobRequest], + Awaitable[gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob], + ]: r"""Return a callable for the create model deployment monitoring job method over gRPC. @@ -835,18 +905,23 @@ def create_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_model_deployment_monitoring_job' not in self._stubs: - self._stubs['create_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateModelDeploymentMonitoringJob', + if "create_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "create_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateModelDeploymentMonitoringJob", request_serializer=job_service.CreateModelDeploymentMonitoringJobRequest.serialize, response_deserializer=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, ) - return self._stubs['create_model_deployment_monitoring_job'] + return self._stubs["create_model_deployment_monitoring_job"] @property - def search_model_deployment_monitoring_stats_anomalies(self) -> Callable[ - [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], - Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse]]: + def search_model_deployment_monitoring_stats_anomalies( + self, + ) -> Callable[ + [job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest], + Awaitable[job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse], + ]: r"""Return a callable for the search model deployment monitoring stats anomalies method over gRPC. @@ -863,18 +938,23 @@ def search_model_deployment_monitoring_stats_anomalies(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_model_deployment_monitoring_stats_anomalies' not in self._stubs: - self._stubs['search_model_deployment_monitoring_stats_anomalies'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/SearchModelDeploymentMonitoringStatsAnomalies', + if "search_model_deployment_monitoring_stats_anomalies" not in self._stubs: + self._stubs[ + "search_model_deployment_monitoring_stats_anomalies" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/SearchModelDeploymentMonitoringStatsAnomalies", request_serializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest.serialize, response_deserializer=job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse.deserialize, ) - return self._stubs['search_model_deployment_monitoring_stats_anomalies'] + return self._stubs["search_model_deployment_monitoring_stats_anomalies"] @property - def get_model_deployment_monitoring_job(self) -> Callable[ - [job_service.GetModelDeploymentMonitoringJobRequest], - Awaitable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob]]: + def get_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.GetModelDeploymentMonitoringJobRequest], + Awaitable[model_deployment_monitoring_job.ModelDeploymentMonitoringJob], + ]: r"""Return a callable for the get model deployment monitoring job method over gRPC. @@ -890,18 +970,23 @@ def get_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_deployment_monitoring_job' not in self._stubs: - self._stubs['get_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetModelDeploymentMonitoringJob', + if "get_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "get_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetModelDeploymentMonitoringJob", request_serializer=job_service.GetModelDeploymentMonitoringJobRequest.serialize, response_deserializer=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.deserialize, ) - return self._stubs['get_model_deployment_monitoring_job'] + return self._stubs["get_model_deployment_monitoring_job"] @property - def list_model_deployment_monitoring_jobs(self) -> Callable[ - [job_service.ListModelDeploymentMonitoringJobsRequest], - Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse]]: + def list_model_deployment_monitoring_jobs( + self, + ) -> Callable[ + [job_service.ListModelDeploymentMonitoringJobsRequest], + Awaitable[job_service.ListModelDeploymentMonitoringJobsResponse], + ]: r"""Return a callable for the list model deployment monitoring jobs method over gRPC. @@ -917,18 +1002,23 @@ def list_model_deployment_monitoring_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_deployment_monitoring_jobs' not in self._stubs: - self._stubs['list_model_deployment_monitoring_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListModelDeploymentMonitoringJobs', + if "list_model_deployment_monitoring_jobs" not in self._stubs: + self._stubs[ + "list_model_deployment_monitoring_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListModelDeploymentMonitoringJobs", request_serializer=job_service.ListModelDeploymentMonitoringJobsRequest.serialize, response_deserializer=job_service.ListModelDeploymentMonitoringJobsResponse.deserialize, ) - return self._stubs['list_model_deployment_monitoring_jobs'] + return self._stubs["list_model_deployment_monitoring_jobs"] @property - def update_model_deployment_monitoring_job(self) -> Callable[ - [job_service.UpdateModelDeploymentMonitoringJobRequest], - Awaitable[operations.Operation]]: + def update_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.UpdateModelDeploymentMonitoringJobRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the update model deployment monitoring job method over gRPC. @@ -944,18 +1034,23 @@ def update_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model_deployment_monitoring_job' not in self._stubs: - self._stubs['update_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/UpdateModelDeploymentMonitoringJob', + if "update_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "update_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/UpdateModelDeploymentMonitoringJob", request_serializer=job_service.UpdateModelDeploymentMonitoringJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_model_deployment_monitoring_job'] + return self._stubs["update_model_deployment_monitoring_job"] @property - def delete_model_deployment_monitoring_job(self) -> Callable[ - [job_service.DeleteModelDeploymentMonitoringJobRequest], - Awaitable[operations.Operation]]: + def delete_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.DeleteModelDeploymentMonitoringJobRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete model deployment monitoring job method over gRPC. @@ -971,18 +1066,22 @@ def delete_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model_deployment_monitoring_job' not in self._stubs: - self._stubs['delete_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteModelDeploymentMonitoringJob', + if "delete_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "delete_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteModelDeploymentMonitoringJob", request_serializer=job_service.DeleteModelDeploymentMonitoringJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model_deployment_monitoring_job'] + return self._stubs["delete_model_deployment_monitoring_job"] @property - def pause_model_deployment_monitoring_job(self) -> Callable[ - [job_service.PauseModelDeploymentMonitoringJobRequest], - Awaitable[empty.Empty]]: + def pause_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.PauseModelDeploymentMonitoringJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the pause model deployment monitoring job method over gRPC. @@ -1001,18 +1100,22 @@ def pause_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'pause_model_deployment_monitoring_job' not in self._stubs: - self._stubs['pause_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/PauseModelDeploymentMonitoringJob', + if "pause_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "pause_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/PauseModelDeploymentMonitoringJob", request_serializer=job_service.PauseModelDeploymentMonitoringJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['pause_model_deployment_monitoring_job'] + return self._stubs["pause_model_deployment_monitoring_job"] @property - def resume_model_deployment_monitoring_job(self) -> Callable[ - [job_service.ResumeModelDeploymentMonitoringJobRequest], - Awaitable[empty.Empty]]: + def resume_model_deployment_monitoring_job( + self, + ) -> Callable[ + [job_service.ResumeModelDeploymentMonitoringJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the resume model deployment monitoring job method over gRPC. @@ -1030,15 +1133,15 @@ def resume_model_deployment_monitoring_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'resume_model_deployment_monitoring_job' not in self._stubs: - self._stubs['resume_model_deployment_monitoring_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ResumeModelDeploymentMonitoringJob', + if "resume_model_deployment_monitoring_job" not in self._stubs: + self._stubs[ + "resume_model_deployment_monitoring_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ResumeModelDeploymentMonitoringJob", request_serializer=job_service.ResumeModelDeploymentMonitoringJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['resume_model_deployment_monitoring_job'] + return self._stubs["resume_model_deployment_monitoring_job"] -__all__ = ( - 'JobServiceGrpcAsyncIOTransport', -) +__all__ = ("JobServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py index 1f8cc4b7fb..8e9c09c94d 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import MetadataServiceAsyncClient __all__ = ( - 'MetadataServiceClient', - 'MetadataServiceAsyncClient', + "MetadataServiceClient", + "MetadataServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index d47a250882..912e00d8e1 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -71,24 +71,42 @@ class MetadataServiceAsyncClient: execution_path = staticmethod(MetadataServiceClient.execution_path) parse_execution_path = staticmethod(MetadataServiceClient.parse_execution_path) metadata_schema_path = staticmethod(MetadataServiceClient.metadata_schema_path) - parse_metadata_schema_path = staticmethod(MetadataServiceClient.parse_metadata_schema_path) + parse_metadata_schema_path = staticmethod( + MetadataServiceClient.parse_metadata_schema_path + ) metadata_store_path = staticmethod(MetadataServiceClient.metadata_store_path) - parse_metadata_store_path = staticmethod(MetadataServiceClient.parse_metadata_store_path) + parse_metadata_store_path = staticmethod( + MetadataServiceClient.parse_metadata_store_path + ) - common_billing_account_path = staticmethod(MetadataServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(MetadataServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + MetadataServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + MetadataServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(MetadataServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(MetadataServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + MetadataServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(MetadataServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(MetadataServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + MetadataServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + MetadataServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(MetadataServiceClient.common_project_path) - parse_common_project_path = staticmethod(MetadataServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + MetadataServiceClient.parse_common_project_path + ) common_location_path = staticmethod(MetadataServiceClient.common_location_path) - parse_common_location_path = staticmethod(MetadataServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + MetadataServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -131,14 +149,18 @@ def transport(self) -> MetadataServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(MetadataServiceClient).get_transport_class, type(MetadataServiceClient)) + get_transport_class = functools.partial( + type(MetadataServiceClient).get_transport_class, type(MetadataServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, MetadataServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, MetadataServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the metadata service client. Args: @@ -177,19 +199,19 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_metadata_store(self, - request: metadata_service.CreateMetadataStoreRequest = None, - *, - parent: str = None, - metadata_store: gca_metadata_store.MetadataStore = None, - metadata_store_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_metadata_store( + self, + request: metadata_service.CreateMetadataStoreRequest = None, + *, + parent: str = None, + metadata_store: gca_metadata_store.MetadataStore = None, + metadata_store_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Initializes a MetadataStore, including allocation of resources. @@ -248,8 +270,10 @@ async def create_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.CreateMetadataStoreRequest(request) @@ -274,18 +298,11 @@ async def create_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -298,14 +315,15 @@ async def create_metadata_store(self, # Done; return the response. return response - async def get_metadata_store(self, - request: metadata_service.GetMetadataStoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_store.MetadataStore: + async def get_metadata_store( + self, + request: metadata_service.GetMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_store.MetadataStore: r"""Retrieves a specific MetadataStore. Args: @@ -339,8 +357,10 @@ async def get_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.GetMetadataStoreRequest(request) @@ -361,30 +381,24 @@ async def get_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_metadata_stores(self, - request: metadata_service.ListMetadataStoresRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListMetadataStoresAsyncPager: + async def list_metadata_stores( + self, + request: metadata_service.ListMetadataStoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataStoresAsyncPager: r"""Lists MetadataStores for a Location. Args: @@ -420,8 +434,10 @@ async def list_metadata_stores(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.ListMetadataStoresRequest(request) @@ -442,39 +458,30 @@ async def list_metadata_stores(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListMetadataStoresAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_metadata_store(self, - request: metadata_service.DeleteMetadataStoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_metadata_store( + self, + request: metadata_service.DeleteMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a single MetadataStore. Args: @@ -520,8 +527,10 @@ async def delete_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.DeleteMetadataStoreRequest(request) @@ -542,18 +551,11 @@ async def delete_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -566,16 +568,17 @@ async def delete_metadata_store(self, # Done; return the response. return response - async def create_artifact(self, - request: metadata_service.CreateArtifactRequest = None, - *, - parent: str = None, - artifact: gca_artifact.Artifact = None, - artifact_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_artifact.Artifact: + async def create_artifact( + self, + request: metadata_service.CreateArtifactRequest = None, + *, + parent: str = None, + artifact: gca_artifact.Artifact = None, + artifact_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: r"""Creates an Artifact associated with a MetadataStore. Args: @@ -627,8 +630,10 @@ async def create_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.CreateArtifactRequest(request) @@ -653,30 +658,24 @@ async def create_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_artifact(self, - request: metadata_service.GetArtifactRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> artifact.Artifact: + async def get_artifact( + self, + request: metadata_service.GetArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> artifact.Artifact: r"""Retrieves a specific Artifact. Args: @@ -707,8 +706,10 @@ async def get_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.GetArtifactRequest(request) @@ -729,30 +730,24 @@ async def get_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_artifacts(self, - request: metadata_service.ListArtifactsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListArtifactsAsyncPager: + async def list_artifacts( + self, + request: metadata_service.ListArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListArtifactsAsyncPager: r"""Lists Artifacts in the MetadataStore. Args: @@ -788,8 +783,10 @@ async def list_artifacts(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.ListArtifactsRequest(request) @@ -810,40 +807,31 @@ async def list_artifacts(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListArtifactsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_artifact(self, - request: metadata_service.UpdateArtifactRequest = None, - *, - artifact: gca_artifact.Artifact = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_artifact.Artifact: + async def update_artifact( + self, + request: metadata_service.UpdateArtifactRequest = None, + *, + artifact: gca_artifact.Artifact = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: r"""Updates a stored Artifact. Args: @@ -884,8 +872,10 @@ async def update_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.UpdateArtifactRequest(request) @@ -908,32 +898,28 @@ async def update_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('artifact.name', request.artifact.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("artifact.name", request.artifact.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def create_context(self, - request: metadata_service.CreateContextRequest = None, - *, - parent: str = None, - context: gca_context.Context = None, - context_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_context.Context: + async def create_context( + self, + request: metadata_service.CreateContextRequest = None, + *, + parent: str = None, + context: gca_context.Context = None, + context_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: r"""Creates a Context associated with a MetadataStore. Args: @@ -985,8 +971,10 @@ async def create_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.CreateContextRequest(request) @@ -1011,30 +999,24 @@ async def create_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_context(self, - request: metadata_service.GetContextRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> context.Context: + async def get_context( + self, + request: metadata_service.GetContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> context.Context: r"""Retrieves a specific Context. Args: @@ -1065,8 +1047,10 @@ async def get_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.GetContextRequest(request) @@ -1087,30 +1071,24 @@ async def get_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_contexts(self, - request: metadata_service.ListContextsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListContextsAsyncPager: + async def list_contexts( + self, + request: metadata_service.ListContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListContextsAsyncPager: r"""Lists Contexts on the MetadataStore. Args: @@ -1146,8 +1124,10 @@ async def list_contexts(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.ListContextsRequest(request) @@ -1168,40 +1148,31 @@ async def list_contexts(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListContextsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_context(self, - request: metadata_service.UpdateContextRequest = None, - *, - context: gca_context.Context = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_context.Context: + async def update_context( + self, + request: metadata_service.UpdateContextRequest = None, + *, + context: gca_context.Context = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: r"""Updates a stored Context. Args: @@ -1241,8 +1212,10 @@ async def update_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.UpdateContextRequest(request) @@ -1265,30 +1238,26 @@ async def update_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context.name', request.context.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("context.name", request.context.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_context(self, - request: metadata_service.DeleteContextRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_context( + self, + request: metadata_service.DeleteContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a stored Context. Args: @@ -1334,8 +1303,10 @@ async def delete_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.DeleteContextRequest(request) @@ -1356,18 +1327,11 @@ async def delete_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1380,16 +1344,17 @@ async def delete_context(self, # Done; return the response. return response - async def add_context_artifacts_and_executions(self, - request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, - *, - context: str = None, - artifacts: Sequence[str] = None, - executions: Sequence[str] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: + async def add_context_artifacts_and_executions( + self, + request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, + *, + context: str = None, + artifacts: Sequence[str] = None, + executions: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: r"""Adds a set of Artifacts and Executions to a Context. If any of the Artifacts or Executions have already been added to a Context, they are simply skipped. @@ -1439,8 +1404,10 @@ async def add_context_artifacts_and_executions(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) @@ -1466,31 +1433,25 @@ async def add_context_artifacts_and_executions(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def add_context_children(self, - request: metadata_service.AddContextChildrenRequest = None, - *, - context: str = None, - child_contexts: Sequence[str] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddContextChildrenResponse: + async def add_context_children( + self, + request: metadata_service.AddContextChildrenRequest = None, + *, + context: str = None, + child_contexts: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextChildrenResponse: r"""Adds a set of Contexts as children to a parent Context. If any of the child Contexts have already been added to the parent Context, they are simply skipped. If this call would create a @@ -1534,8 +1495,10 @@ async def add_context_children(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.AddContextChildrenRequest(request) @@ -1559,30 +1522,24 @@ async def add_context_children(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def query_context_lineage_subgraph(self, - request: metadata_service.QueryContextLineageSubgraphRequest = None, - *, - context: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + async def query_context_lineage_subgraph( + self, + request: metadata_service.QueryContextLineageSubgraphRequest = None, + *, + context: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Retrieves Artifacts and Executions within the specified Context, connected by Event edges and returned as a LineageSubgraph. @@ -1624,8 +1581,10 @@ async def query_context_lineage_subgraph(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.QueryContextLineageSubgraphRequest(request) @@ -1646,32 +1605,26 @@ async def query_context_lineage_subgraph(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def create_execution(self, - request: metadata_service.CreateExecutionRequest = None, - *, - parent: str = None, - execution: gca_execution.Execution = None, - execution_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_execution.Execution: + async def create_execution( + self, + request: metadata_service.CreateExecutionRequest = None, + *, + parent: str = None, + execution: gca_execution.Execution = None, + execution_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: r"""Creates an Execution associated with a MetadataStore. Args: @@ -1723,8 +1676,10 @@ async def create_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.CreateExecutionRequest(request) @@ -1749,30 +1704,24 @@ async def create_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_execution(self, - request: metadata_service.GetExecutionRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> execution.Execution: + async def get_execution( + self, + request: metadata_service.GetExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> execution.Execution: r"""Retrieves a specific Execution. Args: @@ -1803,8 +1752,10 @@ async def get_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.GetExecutionRequest(request) @@ -1825,30 +1776,24 @@ async def get_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_executions(self, - request: metadata_service.ListExecutionsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListExecutionsAsyncPager: + async def list_executions( + self, + request: metadata_service.ListExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListExecutionsAsyncPager: r"""Lists Executions in the MetadataStore. Args: @@ -1884,8 +1829,10 @@ async def list_executions(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.ListExecutionsRequest(request) @@ -1906,40 +1853,31 @@ async def list_executions(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListExecutionsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_execution(self, - request: metadata_service.UpdateExecutionRequest = None, - *, - execution: gca_execution.Execution = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_execution.Execution: + async def update_execution( + self, + request: metadata_service.UpdateExecutionRequest = None, + *, + execution: gca_execution.Execution = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: r"""Updates a stored Execution. Args: @@ -1980,8 +1918,10 @@ async def update_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.UpdateExecutionRequest(request) @@ -2004,31 +1944,27 @@ async def update_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution.name', request.execution.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution.name", request.execution.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def add_execution_events(self, - request: metadata_service.AddExecutionEventsRequest = None, - *, - execution: str = None, - events: Sequence[event.Event] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddExecutionEventsResponse: + async def add_execution_events( + self, + request: metadata_service.AddExecutionEventsRequest = None, + *, + execution: str = None, + events: Sequence[event.Event] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddExecutionEventsResponse: r"""Adds Events for denoting whether each Artifact was an input or output for a given Execution. If any Events already exist between the Execution and any of the @@ -2070,8 +2006,10 @@ async def add_execution_events(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.AddExecutionEventsRequest(request) @@ -2095,30 +2033,26 @@ async def add_execution_events(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution', request.execution), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution", request.execution),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def query_execution_inputs_and_outputs(self, - request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, - *, - execution: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + async def query_execution_inputs_and_outputs( + self, + request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, + *, + execution: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Obtains the set of input and output Artifacts for this Execution, in the form of LineageSubgraph that also contains the Execution and connecting Events. @@ -2156,8 +2090,10 @@ async def query_execution_inputs_and_outputs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) @@ -2178,32 +2114,28 @@ async def query_execution_inputs_and_outputs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution', request.execution), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution", request.execution),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def create_metadata_schema(self, - request: metadata_service.CreateMetadataSchemaRequest = None, - *, - parent: str = None, - metadata_schema: gca_metadata_schema.MetadataSchema = None, - metadata_schema_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_metadata_schema.MetadataSchema: + async def create_metadata_schema( + self, + request: metadata_service.CreateMetadataSchemaRequest = None, + *, + parent: str = None, + metadata_schema: gca_metadata_schema.MetadataSchema = None, + metadata_schema_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_metadata_schema.MetadataSchema: r"""Creates an MetadataSchema. Args: @@ -2257,8 +2189,10 @@ async def create_metadata_schema(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.CreateMetadataSchemaRequest(request) @@ -2283,30 +2217,24 @@ async def create_metadata_schema(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_metadata_schema(self, - request: metadata_service.GetMetadataSchemaRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_schema.MetadataSchema: + async def get_metadata_schema( + self, + request: metadata_service.GetMetadataSchemaRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_schema.MetadataSchema: r"""Retrieves a specific MetadataSchema. Args: @@ -2337,8 +2265,10 @@ async def get_metadata_schema(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.GetMetadataSchemaRequest(request) @@ -2359,30 +2289,24 @@ async def get_metadata_schema(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_metadata_schemas(self, - request: metadata_service.ListMetadataSchemasRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListMetadataSchemasAsyncPager: + async def list_metadata_schemas( + self, + request: metadata_service.ListMetadataSchemasRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataSchemasAsyncPager: r"""Lists MetadataSchemas. Args: @@ -2419,8 +2343,10 @@ async def list_metadata_schemas(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.ListMetadataSchemasRequest(request) @@ -2441,47 +2367,30 @@ async def list_metadata_schemas(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListMetadataSchemasAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MetadataServiceAsyncClient', -) +__all__ = ("MetadataServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index e1fcc67567..705ac60a12 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -67,13 +67,14 @@ class MetadataServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[MetadataServiceTransport]] - _transport_registry['grpc'] = MetadataServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MetadataServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MetadataServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MetadataServiceTransport]] + _transport_registry["grpc"] = MetadataServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MetadataServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MetadataServiceTransport]: """Return an appropriate transport class. Args: @@ -124,7 +125,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -159,9 +160,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MetadataServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -176,121 +176,172 @@ def transport(self) -> MetadataServiceTransport: return self._transport @staticmethod - def artifact_path(project: str,location: str,metadata_store: str,artifact: str,) -> str: + def artifact_path( + project: str, location: str, metadata_store: str, artifact: str, + ) -> str: """Return a fully-qualified artifact string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(project=project, location=location, metadata_store=metadata_store, artifact=artifact, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) @staticmethod - def parse_artifact_path(path: str) -> Dict[str,str]: + def parse_artifact_path(path: str) -> Dict[str, str]: """Parse a artifact path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/artifacts/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/artifacts/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def context_path(project: str,location: str,metadata_store: str,context: str,) -> str: + def context_path( + project: str, location: str, metadata_store: str, context: str, + ) -> str: """Return a fully-qualified context string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(project=project, location=location, metadata_store=metadata_store, context=context, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) @staticmethod - def parse_context_path(path: str) -> Dict[str,str]: + def parse_context_path(path: str) -> Dict[str, str]: """Parse a context path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def execution_path(project: str,location: str,metadata_store: str,execution: str,) -> str: + def execution_path( + project: str, location: str, metadata_store: str, execution: str, + ) -> str: """Return a fully-qualified execution string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(project=project, location=location, metadata_store=metadata_store, execution=execution, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) @staticmethod - def parse_execution_path(path: str) -> Dict[str,str]: + def parse_execution_path(path: str) -> Dict[str, str]: """Parse a execution path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/executions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/executions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def metadata_schema_path(project: str,location: str,metadata_store: str,metadata_schema: str,) -> str: + def metadata_schema_path( + project: str, location: str, metadata_store: str, metadata_schema: str, + ) -> str: """Return a fully-qualified metadata_schema string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format(project=project, location=location, metadata_store=metadata_store, metadata_schema=metadata_schema, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format( + project=project, + location=location, + metadata_store=metadata_store, + metadata_schema=metadata_schema, + ) @staticmethod - def parse_metadata_schema_path(path: str) -> Dict[str,str]: + def parse_metadata_schema_path(path: str) -> Dict[str, str]: """Parse a metadata_schema path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/metadataSchemas/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/metadataSchemas/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def metadata_store_path(project: str,location: str,metadata_store: str,) -> str: + def metadata_store_path(project: str, location: str, metadata_store: str,) -> str: """Return a fully-qualified metadata_store string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format(project=project, location=location, metadata_store=metadata_store, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format( + project=project, location=location, metadata_store=metadata_store, + ) @staticmethod - def parse_metadata_store_path(path: str) -> Dict[str,str]: + def parse_metadata_store_path(path: str) -> Dict[str, str]: """Parse a metadata_store path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MetadataServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MetadataServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the metadata service client. Args: @@ -334,7 +385,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -344,7 +397,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -356,7 +411,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -368,8 +425,10 @@ def __init__(self, *, if isinstance(transport, MetadataServiceTransport): # transport is a MetadataServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -388,16 +447,17 @@ def __init__(self, *, client_info=client_info, ) - def create_metadata_store(self, - request: metadata_service.CreateMetadataStoreRequest = None, - *, - parent: str = None, - metadata_store: gca_metadata_store.MetadataStore = None, - metadata_store_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_metadata_store( + self, + request: metadata_service.CreateMetadataStoreRequest = None, + *, + parent: str = None, + metadata_store: gca_metadata_store.MetadataStore = None, + metadata_store_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Initializes a MetadataStore, including allocation of resources. @@ -456,8 +516,10 @@ def create_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateMetadataStoreRequest. @@ -483,18 +545,11 @@ def create_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -507,14 +562,15 @@ def create_metadata_store(self, # Done; return the response. return response - def get_metadata_store(self, - request: metadata_service.GetMetadataStoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_store.MetadataStore: + def get_metadata_store( + self, + request: metadata_service.GetMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_store.MetadataStore: r"""Retrieves a specific MetadataStore. Args: @@ -548,8 +604,10 @@ def get_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetMetadataStoreRequest. @@ -571,30 +629,24 @@ def get_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_metadata_stores(self, - request: metadata_service.ListMetadataStoresRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListMetadataStoresPager: + def list_metadata_stores( + self, + request: metadata_service.ListMetadataStoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataStoresPager: r"""Lists MetadataStores for a Location. Args: @@ -630,8 +682,10 @@ def list_metadata_stores(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListMetadataStoresRequest. @@ -653,39 +707,30 @@ def list_metadata_stores(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListMetadataStoresPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_metadata_store(self, - request: metadata_service.DeleteMetadataStoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_metadata_store( + self, + request: metadata_service.DeleteMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a single MetadataStore. Args: @@ -731,8 +776,10 @@ def delete_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.DeleteMetadataStoreRequest. @@ -754,18 +801,11 @@ def delete_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -778,16 +818,17 @@ def delete_metadata_store(self, # Done; return the response. return response - def create_artifact(self, - request: metadata_service.CreateArtifactRequest = None, - *, - parent: str = None, - artifact: gca_artifact.Artifact = None, - artifact_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_artifact.Artifact: + def create_artifact( + self, + request: metadata_service.CreateArtifactRequest = None, + *, + parent: str = None, + artifact: gca_artifact.Artifact = None, + artifact_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: r"""Creates an Artifact associated with a MetadataStore. Args: @@ -839,8 +880,10 @@ def create_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateArtifactRequest. @@ -866,30 +909,24 @@ def create_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_artifact(self, - request: metadata_service.GetArtifactRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> artifact.Artifact: + def get_artifact( + self, + request: metadata_service.GetArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> artifact.Artifact: r"""Retrieves a specific Artifact. Args: @@ -920,8 +957,10 @@ def get_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetArtifactRequest. @@ -943,30 +982,24 @@ def get_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_artifacts(self, - request: metadata_service.ListArtifactsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListArtifactsPager: + def list_artifacts( + self, + request: metadata_service.ListArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListArtifactsPager: r"""Lists Artifacts in the MetadataStore. Args: @@ -1002,8 +1035,10 @@ def list_artifacts(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListArtifactsRequest. @@ -1025,40 +1060,31 @@ def list_artifacts(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListArtifactsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_artifact(self, - request: metadata_service.UpdateArtifactRequest = None, - *, - artifact: gca_artifact.Artifact = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_artifact.Artifact: + def update_artifact( + self, + request: metadata_service.UpdateArtifactRequest = None, + *, + artifact: gca_artifact.Artifact = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: r"""Updates a stored Artifact. Args: @@ -1099,8 +1125,10 @@ def update_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.UpdateArtifactRequest. @@ -1124,32 +1152,28 @@ def update_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('artifact.name', request.artifact.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("artifact.name", request.artifact.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def create_context(self, - request: metadata_service.CreateContextRequest = None, - *, - parent: str = None, - context: gca_context.Context = None, - context_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_context.Context: + def create_context( + self, + request: metadata_service.CreateContextRequest = None, + *, + parent: str = None, + context: gca_context.Context = None, + context_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: r"""Creates a Context associated with a MetadataStore. Args: @@ -1201,8 +1225,10 @@ def create_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateContextRequest. @@ -1228,30 +1254,24 @@ def create_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_context(self, - request: metadata_service.GetContextRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> context.Context: + def get_context( + self, + request: metadata_service.GetContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> context.Context: r"""Retrieves a specific Context. Args: @@ -1282,8 +1302,10 @@ def get_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetContextRequest. @@ -1305,30 +1327,24 @@ def get_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_contexts(self, - request: metadata_service.ListContextsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListContextsPager: + def list_contexts( + self, + request: metadata_service.ListContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListContextsPager: r"""Lists Contexts on the MetadataStore. Args: @@ -1364,8 +1380,10 @@ def list_contexts(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListContextsRequest. @@ -1387,40 +1405,31 @@ def list_contexts(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListContextsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_context(self, - request: metadata_service.UpdateContextRequest = None, - *, - context: gca_context.Context = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_context.Context: + def update_context( + self, + request: metadata_service.UpdateContextRequest = None, + *, + context: gca_context.Context = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: r"""Updates a stored Context. Args: @@ -1460,8 +1469,10 @@ def update_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.UpdateContextRequest. @@ -1485,30 +1496,26 @@ def update_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context.name', request.context.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("context.name", request.context.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_context(self, - request: metadata_service.DeleteContextRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_context( + self, + request: metadata_service.DeleteContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a stored Context. Args: @@ -1554,8 +1561,10 @@ def delete_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.DeleteContextRequest. @@ -1577,18 +1586,11 @@ def delete_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1601,16 +1603,17 @@ def delete_context(self, # Done; return the response. return response - def add_context_artifacts_and_executions(self, - request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, - *, - context: str = None, - artifacts: Sequence[str] = None, - executions: Sequence[str] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: + def add_context_artifacts_and_executions( + self, + request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, + *, + context: str = None, + artifacts: Sequence[str] = None, + executions: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: r"""Adds a set of Artifacts and Executions to a Context. If any of the Artifacts or Executions have already been added to a Context, they are simply skipped. @@ -1660,14 +1663,18 @@ def add_context_artifacts_and_executions(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.AddContextArtifactsAndExecutionsRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, metadata_service.AddContextArtifactsAndExecutionsRequest): + if not isinstance( + request, metadata_service.AddContextArtifactsAndExecutionsRequest + ): request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1682,36 +1689,32 @@ def add_context_artifacts_and_executions(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.add_context_artifacts_and_executions] + rpc = self._transport._wrapped_methods[ + self._transport.add_context_artifacts_and_executions + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def add_context_children(self, - request: metadata_service.AddContextChildrenRequest = None, - *, - context: str = None, - child_contexts: Sequence[str] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddContextChildrenResponse: + def add_context_children( + self, + request: metadata_service.AddContextChildrenRequest = None, + *, + context: str = None, + child_contexts: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextChildrenResponse: r"""Adds a set of Contexts as children to a parent Context. If any of the child Contexts have already been added to the parent Context, they are simply skipped. If this call would create a @@ -1755,8 +1758,10 @@ def add_context_children(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.AddContextChildrenRequest. @@ -1780,30 +1785,24 @@ def add_context_children(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def query_context_lineage_subgraph(self, - request: metadata_service.QueryContextLineageSubgraphRequest = None, - *, - context: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + def query_context_lineage_subgraph( + self, + request: metadata_service.QueryContextLineageSubgraphRequest = None, + *, + context: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Retrieves Artifacts and Executions within the specified Context, connected by Event edges and returned as a LineageSubgraph. @@ -1845,8 +1844,10 @@ def query_context_lineage_subgraph(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.QueryContextLineageSubgraphRequest. @@ -1863,37 +1864,33 @@ def query_context_lineage_subgraph(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.query_context_lineage_subgraph] + rpc = self._transport._wrapped_methods[ + self._transport.query_context_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def create_execution(self, - request: metadata_service.CreateExecutionRequest = None, - *, - parent: str = None, - execution: gca_execution.Execution = None, - execution_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_execution.Execution: + def create_execution( + self, + request: metadata_service.CreateExecutionRequest = None, + *, + parent: str = None, + execution: gca_execution.Execution = None, + execution_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: r"""Creates an Execution associated with a MetadataStore. Args: @@ -1945,8 +1942,10 @@ def create_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateExecutionRequest. @@ -1972,30 +1971,24 @@ def create_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_execution(self, - request: metadata_service.GetExecutionRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> execution.Execution: + def get_execution( + self, + request: metadata_service.GetExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> execution.Execution: r"""Retrieves a specific Execution. Args: @@ -2026,8 +2019,10 @@ def get_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetExecutionRequest. @@ -2049,30 +2044,24 @@ def get_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_executions(self, - request: metadata_service.ListExecutionsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListExecutionsPager: + def list_executions( + self, + request: metadata_service.ListExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListExecutionsPager: r"""Lists Executions in the MetadataStore. Args: @@ -2108,8 +2097,10 @@ def list_executions(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListExecutionsRequest. @@ -2131,40 +2122,31 @@ def list_executions(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListExecutionsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_execution(self, - request: metadata_service.UpdateExecutionRequest = None, - *, - execution: gca_execution.Execution = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_execution.Execution: + def update_execution( + self, + request: metadata_service.UpdateExecutionRequest = None, + *, + execution: gca_execution.Execution = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: r"""Updates a stored Execution. Args: @@ -2205,8 +2187,10 @@ def update_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.UpdateExecutionRequest. @@ -2230,31 +2214,27 @@ def update_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution.name', request.execution.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution.name", request.execution.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def add_execution_events(self, - request: metadata_service.AddExecutionEventsRequest = None, - *, - execution: str = None, - events: Sequence[event.Event] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddExecutionEventsResponse: + def add_execution_events( + self, + request: metadata_service.AddExecutionEventsRequest = None, + *, + execution: str = None, + events: Sequence[event.Event] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddExecutionEventsResponse: r"""Adds Events for denoting whether each Artifact was an input or output for a given Execution. If any Events already exist between the Execution and any of the @@ -2296,8 +2276,10 @@ def add_execution_events(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.AddExecutionEventsRequest. @@ -2321,30 +2303,26 @@ def add_execution_events(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution', request.execution), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution", request.execution),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def query_execution_inputs_and_outputs(self, - request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, - *, - execution: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + def query_execution_inputs_and_outputs( + self, + request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, + *, + execution: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Obtains the set of input and output Artifacts for this Execution, in the form of LineageSubgraph that also contains the Execution and connecting Events. @@ -2382,14 +2360,18 @@ def query_execution_inputs_and_outputs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.QueryExecutionInputsAndOutputsRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, metadata_service.QueryExecutionInputsAndOutputsRequest): + if not isinstance( + request, metadata_service.QueryExecutionInputsAndOutputsRequest + ): request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2400,37 +2382,35 @@ def query_execution_inputs_and_outputs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.query_execution_inputs_and_outputs] + rpc = self._transport._wrapped_methods[ + self._transport.query_execution_inputs_and_outputs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution', request.execution), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution", request.execution),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def create_metadata_schema(self, - request: metadata_service.CreateMetadataSchemaRequest = None, - *, - parent: str = None, - metadata_schema: gca_metadata_schema.MetadataSchema = None, - metadata_schema_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_metadata_schema.MetadataSchema: + def create_metadata_schema( + self, + request: metadata_service.CreateMetadataSchemaRequest = None, + *, + parent: str = None, + metadata_schema: gca_metadata_schema.MetadataSchema = None, + metadata_schema_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_metadata_schema.MetadataSchema: r"""Creates an MetadataSchema. Args: @@ -2484,8 +2464,10 @@ def create_metadata_schema(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateMetadataSchemaRequest. @@ -2511,30 +2493,24 @@ def create_metadata_schema(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_metadata_schema(self, - request: metadata_service.GetMetadataSchemaRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_schema.MetadataSchema: + def get_metadata_schema( + self, + request: metadata_service.GetMetadataSchemaRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_schema.MetadataSchema: r"""Retrieves a specific MetadataSchema. Args: @@ -2565,8 +2541,10 @@ def get_metadata_schema(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetMetadataSchemaRequest. @@ -2588,30 +2566,24 @@ def get_metadata_schema(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_metadata_schemas(self, - request: metadata_service.ListMetadataSchemasRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListMetadataSchemasPager: + def list_metadata_schemas( + self, + request: metadata_service.ListMetadataSchemasRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataSchemasPager: r"""Lists MetadataSchemas. Args: @@ -2648,8 +2620,10 @@ def list_metadata_schemas(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListMetadataSchemasRequest. @@ -2671,47 +2645,30 @@ def list_metadata_schemas(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListMetadataSchemasPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MetadataServiceClient', -) +__all__ = ("MetadataServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py index da04d2882f..979c99e4e8 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import artifact from google.cloud.aiplatform_v1beta1.types import context @@ -42,12 +51,15 @@ class ListMetadataStoresPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., metadata_service.ListMetadataStoresResponse], - request: metadata_service.ListMetadataStoresRequest, - response: metadata_service.ListMetadataStoresResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., metadata_service.ListMetadataStoresResponse], + request: metadata_service.ListMetadataStoresRequest, + response: metadata_service.ListMetadataStoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -81,7 +93,7 @@ def __iter__(self) -> Iterable[metadata_store.MetadataStore]: yield from page.metadata_stores def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListMetadataStoresAsyncPager: @@ -101,12 +113,15 @@ class ListMetadataStoresAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[metadata_service.ListMetadataStoresResponse]], - request: metadata_service.ListMetadataStoresRequest, - response: metadata_service.ListMetadataStoresResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[metadata_service.ListMetadataStoresResponse]], + request: metadata_service.ListMetadataStoresRequest, + response: metadata_service.ListMetadataStoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -144,7 +159,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListArtifactsPager: @@ -164,12 +179,15 @@ class ListArtifactsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., metadata_service.ListArtifactsResponse], - request: metadata_service.ListArtifactsRequest, - response: metadata_service.ListArtifactsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., metadata_service.ListArtifactsResponse], + request: metadata_service.ListArtifactsRequest, + response: metadata_service.ListArtifactsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -203,7 +221,7 @@ def __iter__(self) -> Iterable[artifact.Artifact]: yield from page.artifacts def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListArtifactsAsyncPager: @@ -223,12 +241,15 @@ class ListArtifactsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[metadata_service.ListArtifactsResponse]], - request: metadata_service.ListArtifactsRequest, - response: metadata_service.ListArtifactsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[metadata_service.ListArtifactsResponse]], + request: metadata_service.ListArtifactsRequest, + response: metadata_service.ListArtifactsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -266,7 +287,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListContextsPager: @@ -286,12 +307,15 @@ class ListContextsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., metadata_service.ListContextsResponse], - request: metadata_service.ListContextsRequest, - response: metadata_service.ListContextsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., metadata_service.ListContextsResponse], + request: metadata_service.ListContextsRequest, + response: metadata_service.ListContextsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -325,7 +349,7 @@ def __iter__(self) -> Iterable[context.Context]: yield from page.contexts def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListContextsAsyncPager: @@ -345,12 +369,15 @@ class ListContextsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[metadata_service.ListContextsResponse]], - request: metadata_service.ListContextsRequest, - response: metadata_service.ListContextsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[metadata_service.ListContextsResponse]], + request: metadata_service.ListContextsRequest, + response: metadata_service.ListContextsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -388,7 +415,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListExecutionsPager: @@ -408,12 +435,15 @@ class ListExecutionsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., metadata_service.ListExecutionsResponse], - request: metadata_service.ListExecutionsRequest, - response: metadata_service.ListExecutionsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., metadata_service.ListExecutionsResponse], + request: metadata_service.ListExecutionsRequest, + response: metadata_service.ListExecutionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -447,7 +477,7 @@ def __iter__(self) -> Iterable[execution.Execution]: yield from page.executions def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListExecutionsAsyncPager: @@ -467,12 +497,15 @@ class ListExecutionsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[metadata_service.ListExecutionsResponse]], - request: metadata_service.ListExecutionsRequest, - response: metadata_service.ListExecutionsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[metadata_service.ListExecutionsResponse]], + request: metadata_service.ListExecutionsRequest, + response: metadata_service.ListExecutionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -510,7 +543,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListMetadataSchemasPager: @@ -530,12 +563,15 @@ class ListMetadataSchemasPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., metadata_service.ListMetadataSchemasResponse], - request: metadata_service.ListMetadataSchemasRequest, - response: metadata_service.ListMetadataSchemasResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., metadata_service.ListMetadataSchemasResponse], + request: metadata_service.ListMetadataSchemasRequest, + response: metadata_service.ListMetadataSchemasResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -569,7 +605,7 @@ def __iter__(self) -> Iterable[metadata_schema.MetadataSchema]: yield from page.metadata_schemas def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListMetadataSchemasAsyncPager: @@ -589,12 +625,15 @@ class ListMetadataSchemasAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[metadata_service.ListMetadataSchemasResponse]], - request: metadata_service.ListMetadataSchemasRequest, - response: metadata_service.ListMetadataSchemasResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[metadata_service.ListMetadataSchemasResponse]], + request: metadata_service.ListMetadataSchemasRequest, + response: metadata_service.ListMetadataSchemasResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -616,7 +655,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[metadata_service.ListMetadataSchemasResponse]: + async def pages( + self, + ) -> AsyncIterable[metadata_service.ListMetadataSchemasResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -632,4 +673,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py index 67031880cd..a01e7ca986 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[MetadataServiceTransport]] -_transport_registry['grpc'] = MetadataServiceGrpcTransport -_transport_registry['grpc_asyncio'] = MetadataServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = MetadataServiceGrpcTransport +_transport_registry["grpc_asyncio"] = MetadataServiceGrpcAsyncIOTransport __all__ = ( - 'MetadataServiceTransport', - 'MetadataServiceGrpcTransport', - 'MetadataServiceGrpcAsyncIOTransport', + "MetadataServiceTransport", + "MetadataServiceGrpcTransport", + "MetadataServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py index 76ef934c98..4991d3a471 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -43,29 +43,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class MetadataServiceTransport(abc.ABC): """Abstract transport class for MetadataService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -88,8 +88,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -98,17 +98,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -122,9 +124,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_metadata_store: gapic_v1.method.wrap_method( - self.get_metadata_store, - default_timeout=None, - client_info=client_info, + self.get_metadata_store, default_timeout=None, client_info=client_info, ), self.list_metadata_stores: gapic_v1.method.wrap_method( self.list_metadata_stores, @@ -137,49 +137,31 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.create_artifact: gapic_v1.method.wrap_method( - self.create_artifact, - default_timeout=None, - client_info=client_info, + self.create_artifact, default_timeout=None, client_info=client_info, ), self.get_artifact: gapic_v1.method.wrap_method( - self.get_artifact, - default_timeout=None, - client_info=client_info, + self.get_artifact, default_timeout=None, client_info=client_info, ), self.list_artifacts: gapic_v1.method.wrap_method( - self.list_artifacts, - default_timeout=None, - client_info=client_info, + self.list_artifacts, default_timeout=None, client_info=client_info, ), self.update_artifact: gapic_v1.method.wrap_method( - self.update_artifact, - default_timeout=None, - client_info=client_info, + self.update_artifact, default_timeout=None, client_info=client_info, ), self.create_context: gapic_v1.method.wrap_method( - self.create_context, - default_timeout=None, - client_info=client_info, + self.create_context, default_timeout=None, client_info=client_info, ), self.get_context: gapic_v1.method.wrap_method( - self.get_context, - default_timeout=None, - client_info=client_info, + self.get_context, default_timeout=None, client_info=client_info, ), self.list_contexts: gapic_v1.method.wrap_method( - self.list_contexts, - default_timeout=None, - client_info=client_info, + self.list_contexts, default_timeout=None, client_info=client_info, ), self.update_context: gapic_v1.method.wrap_method( - self.update_context, - default_timeout=None, - client_info=client_info, + self.update_context, default_timeout=None, client_info=client_info, ), self.delete_context: gapic_v1.method.wrap_method( - self.delete_context, - default_timeout=None, - client_info=client_info, + self.delete_context, default_timeout=None, client_info=client_info, ), self.add_context_artifacts_and_executions: gapic_v1.method.wrap_method( self.add_context_artifacts_and_executions, @@ -197,24 +179,16 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.create_execution: gapic_v1.method.wrap_method( - self.create_execution, - default_timeout=None, - client_info=client_info, + self.create_execution, default_timeout=None, client_info=client_info, ), self.get_execution: gapic_v1.method.wrap_method( - self.get_execution, - default_timeout=None, - client_info=client_info, + self.get_execution, default_timeout=None, client_info=client_info, ), self.list_executions: gapic_v1.method.wrap_method( - self.list_executions, - default_timeout=None, - client_info=client_info, + self.list_executions, default_timeout=None, client_info=client_info, ), self.update_execution: gapic_v1.method.wrap_method( - self.update_execution, - default_timeout=None, - client_info=client_info, + self.update_execution, default_timeout=None, client_info=client_info, ), self.add_execution_events: gapic_v1.method.wrap_method( self.add_execution_events, @@ -232,16 +206,13 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_metadata_schema: gapic_v1.method.wrap_method( - self.get_metadata_schema, - default_timeout=None, - client_info=client_info, + self.get_metadata_schema, default_timeout=None, client_info=client_info, ), self.list_metadata_schemas: gapic_v1.method.wrap_method( self.list_metadata_schemas, default_timeout=None, client_info=client_info, ), - } @property @@ -250,231 +221,271 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_metadata_store(self) -> typing.Callable[ - [metadata_service.CreateMetadataStoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_metadata_store( + self, + ) -> typing.Callable[ + [metadata_service.CreateMetadataStoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_metadata_store(self) -> typing.Callable[ - [metadata_service.GetMetadataStoreRequest], - typing.Union[ - metadata_store.MetadataStore, - typing.Awaitable[metadata_store.MetadataStore] - ]]: + def get_metadata_store( + self, + ) -> typing.Callable[ + [metadata_service.GetMetadataStoreRequest], + typing.Union[ + metadata_store.MetadataStore, typing.Awaitable[metadata_store.MetadataStore] + ], + ]: raise NotImplementedError() @property - def list_metadata_stores(self) -> typing.Callable[ - [metadata_service.ListMetadataStoresRequest], - typing.Union[ - metadata_service.ListMetadataStoresResponse, - typing.Awaitable[metadata_service.ListMetadataStoresResponse] - ]]: + def list_metadata_stores( + self, + ) -> typing.Callable[ + [metadata_service.ListMetadataStoresRequest], + typing.Union[ + metadata_service.ListMetadataStoresResponse, + typing.Awaitable[metadata_service.ListMetadataStoresResponse], + ], + ]: raise NotImplementedError() @property - def delete_metadata_store(self) -> typing.Callable[ - [metadata_service.DeleteMetadataStoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_metadata_store( + self, + ) -> typing.Callable[ + [metadata_service.DeleteMetadataStoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def create_artifact(self) -> typing.Callable[ - [metadata_service.CreateArtifactRequest], - typing.Union[ - gca_artifact.Artifact, - typing.Awaitable[gca_artifact.Artifact] - ]]: + def create_artifact( + self, + ) -> typing.Callable[ + [metadata_service.CreateArtifactRequest], + typing.Union[gca_artifact.Artifact, typing.Awaitable[gca_artifact.Artifact]], + ]: raise NotImplementedError() @property - def get_artifact(self) -> typing.Callable[ - [metadata_service.GetArtifactRequest], - typing.Union[ - artifact.Artifact, - typing.Awaitable[artifact.Artifact] - ]]: + def get_artifact( + self, + ) -> typing.Callable[ + [metadata_service.GetArtifactRequest], + typing.Union[artifact.Artifact, typing.Awaitable[artifact.Artifact]], + ]: raise NotImplementedError() @property - def list_artifacts(self) -> typing.Callable[ - [metadata_service.ListArtifactsRequest], - typing.Union[ - metadata_service.ListArtifactsResponse, - typing.Awaitable[metadata_service.ListArtifactsResponse] - ]]: + def list_artifacts( + self, + ) -> typing.Callable[ + [metadata_service.ListArtifactsRequest], + typing.Union[ + metadata_service.ListArtifactsResponse, + typing.Awaitable[metadata_service.ListArtifactsResponse], + ], + ]: raise NotImplementedError() @property - def update_artifact(self) -> typing.Callable[ - [metadata_service.UpdateArtifactRequest], - typing.Union[ - gca_artifact.Artifact, - typing.Awaitable[gca_artifact.Artifact] - ]]: + def update_artifact( + self, + ) -> typing.Callable[ + [metadata_service.UpdateArtifactRequest], + typing.Union[gca_artifact.Artifact, typing.Awaitable[gca_artifact.Artifact]], + ]: raise NotImplementedError() @property - def create_context(self) -> typing.Callable[ - [metadata_service.CreateContextRequest], - typing.Union[ - gca_context.Context, - typing.Awaitable[gca_context.Context] - ]]: + def create_context( + self, + ) -> typing.Callable[ + [metadata_service.CreateContextRequest], + typing.Union[gca_context.Context, typing.Awaitable[gca_context.Context]], + ]: raise NotImplementedError() @property - def get_context(self) -> typing.Callable[ - [metadata_service.GetContextRequest], - typing.Union[ - context.Context, - typing.Awaitable[context.Context] - ]]: + def get_context( + self, + ) -> typing.Callable[ + [metadata_service.GetContextRequest], + typing.Union[context.Context, typing.Awaitable[context.Context]], + ]: raise NotImplementedError() @property - def list_contexts(self) -> typing.Callable[ - [metadata_service.ListContextsRequest], - typing.Union[ - metadata_service.ListContextsResponse, - typing.Awaitable[metadata_service.ListContextsResponse] - ]]: + def list_contexts( + self, + ) -> typing.Callable[ + [metadata_service.ListContextsRequest], + typing.Union[ + metadata_service.ListContextsResponse, + typing.Awaitable[metadata_service.ListContextsResponse], + ], + ]: raise NotImplementedError() @property - def update_context(self) -> typing.Callable[ - [metadata_service.UpdateContextRequest], - typing.Union[ - gca_context.Context, - typing.Awaitable[gca_context.Context] - ]]: + def update_context( + self, + ) -> typing.Callable[ + [metadata_service.UpdateContextRequest], + typing.Union[gca_context.Context, typing.Awaitable[gca_context.Context]], + ]: raise NotImplementedError() @property - def delete_context(self) -> typing.Callable[ - [metadata_service.DeleteContextRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_context( + self, + ) -> typing.Callable[ + [metadata_service.DeleteContextRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def add_context_artifacts_and_executions(self) -> typing.Callable[ - [metadata_service.AddContextArtifactsAndExecutionsRequest], - typing.Union[ - metadata_service.AddContextArtifactsAndExecutionsResponse, - typing.Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse] - ]]: + def add_context_artifacts_and_executions( + self, + ) -> typing.Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + typing.Union[ + metadata_service.AddContextArtifactsAndExecutionsResponse, + typing.Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse], + ], + ]: raise NotImplementedError() @property - def add_context_children(self) -> typing.Callable[ - [metadata_service.AddContextChildrenRequest], - typing.Union[ - metadata_service.AddContextChildrenResponse, - typing.Awaitable[metadata_service.AddContextChildrenResponse] - ]]: + def add_context_children( + self, + ) -> typing.Callable[ + [metadata_service.AddContextChildrenRequest], + typing.Union[ + metadata_service.AddContextChildrenResponse, + typing.Awaitable[metadata_service.AddContextChildrenResponse], + ], + ]: raise NotImplementedError() @property - def query_context_lineage_subgraph(self) -> typing.Callable[ - [metadata_service.QueryContextLineageSubgraphRequest], - typing.Union[ - lineage_subgraph.LineageSubgraph, - typing.Awaitable[lineage_subgraph.LineageSubgraph] - ]]: + def query_context_lineage_subgraph( + self, + ) -> typing.Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph], + ], + ]: raise NotImplementedError() @property - def create_execution(self) -> typing.Callable[ - [metadata_service.CreateExecutionRequest], - typing.Union[ - gca_execution.Execution, - typing.Awaitable[gca_execution.Execution] - ]]: + def create_execution( + self, + ) -> typing.Callable[ + [metadata_service.CreateExecutionRequest], + typing.Union[ + gca_execution.Execution, typing.Awaitable[gca_execution.Execution] + ], + ]: raise NotImplementedError() @property - def get_execution(self) -> typing.Callable[ - [metadata_service.GetExecutionRequest], - typing.Union[ - execution.Execution, - typing.Awaitable[execution.Execution] - ]]: + def get_execution( + self, + ) -> typing.Callable[ + [metadata_service.GetExecutionRequest], + typing.Union[execution.Execution, typing.Awaitable[execution.Execution]], + ]: raise NotImplementedError() @property - def list_executions(self) -> typing.Callable[ - [metadata_service.ListExecutionsRequest], - typing.Union[ - metadata_service.ListExecutionsResponse, - typing.Awaitable[metadata_service.ListExecutionsResponse] - ]]: + def list_executions( + self, + ) -> typing.Callable[ + [metadata_service.ListExecutionsRequest], + typing.Union[ + metadata_service.ListExecutionsResponse, + typing.Awaitable[metadata_service.ListExecutionsResponse], + ], + ]: raise NotImplementedError() @property - def update_execution(self) -> typing.Callable[ - [metadata_service.UpdateExecutionRequest], - typing.Union[ - gca_execution.Execution, - typing.Awaitable[gca_execution.Execution] - ]]: + def update_execution( + self, + ) -> typing.Callable[ + [metadata_service.UpdateExecutionRequest], + typing.Union[ + gca_execution.Execution, typing.Awaitable[gca_execution.Execution] + ], + ]: raise NotImplementedError() @property - def add_execution_events(self) -> typing.Callable[ - [metadata_service.AddExecutionEventsRequest], - typing.Union[ - metadata_service.AddExecutionEventsResponse, - typing.Awaitable[metadata_service.AddExecutionEventsResponse] - ]]: + def add_execution_events( + self, + ) -> typing.Callable[ + [metadata_service.AddExecutionEventsRequest], + typing.Union[ + metadata_service.AddExecutionEventsResponse, + typing.Awaitable[metadata_service.AddExecutionEventsResponse], + ], + ]: raise NotImplementedError() @property - def query_execution_inputs_and_outputs(self) -> typing.Callable[ - [metadata_service.QueryExecutionInputsAndOutputsRequest], - typing.Union[ - lineage_subgraph.LineageSubgraph, - typing.Awaitable[lineage_subgraph.LineageSubgraph] - ]]: + def query_execution_inputs_and_outputs( + self, + ) -> typing.Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph], + ], + ]: raise NotImplementedError() @property - def create_metadata_schema(self) -> typing.Callable[ - [metadata_service.CreateMetadataSchemaRequest], - typing.Union[ - gca_metadata_schema.MetadataSchema, - typing.Awaitable[gca_metadata_schema.MetadataSchema] - ]]: + def create_metadata_schema( + self, + ) -> typing.Callable[ + [metadata_service.CreateMetadataSchemaRequest], + typing.Union[ + gca_metadata_schema.MetadataSchema, + typing.Awaitable[gca_metadata_schema.MetadataSchema], + ], + ]: raise NotImplementedError() @property - def get_metadata_schema(self) -> typing.Callable[ - [metadata_service.GetMetadataSchemaRequest], - typing.Union[ - metadata_schema.MetadataSchema, - typing.Awaitable[metadata_schema.MetadataSchema] - ]]: + def get_metadata_schema( + self, + ) -> typing.Callable[ + [metadata_service.GetMetadataSchemaRequest], + typing.Union[ + metadata_schema.MetadataSchema, + typing.Awaitable[metadata_schema.MetadataSchema], + ], + ]: raise NotImplementedError() @property - def list_metadata_schemas(self) -> typing.Callable[ - [metadata_service.ListMetadataSchemasRequest], - typing.Union[ - metadata_service.ListMetadataSchemasResponse, - typing.Awaitable[metadata_service.ListMetadataSchemasResponse] - ]]: + def list_metadata_schemas( + self, + ) -> typing.Callable[ + [metadata_service.ListMetadataSchemasRequest], + typing.Union[ + metadata_service.ListMetadataSchemasResponse, + typing.Awaitable[metadata_service.ListMetadataSchemasResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'MetadataServiceTransport', -) +__all__ = ("MetadataServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py index 7cc6484f91..6d9739c3be 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -55,21 +55,24 @@ class MetadataServiceGrpcTransport(MetadataServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -181,13 +184,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -220,7 +225,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -238,17 +243,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_metadata_store(self) -> Callable[ - [metadata_service.CreateMetadataStoreRequest], - operations.Operation]: + def create_metadata_store( + self, + ) -> Callable[[metadata_service.CreateMetadataStoreRequest], operations.Operation]: r"""Return a callable for the create metadata store method over gRPC. Initializes a MetadataStore, including allocation of @@ -264,18 +267,20 @@ def create_metadata_store(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_metadata_store' not in self._stubs: - self._stubs['create_metadata_store'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataStore', + if "create_metadata_store" not in self._stubs: + self._stubs["create_metadata_store"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataStore", request_serializer=metadata_service.CreateMetadataStoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_metadata_store'] + return self._stubs["create_metadata_store"] @property - def get_metadata_store(self) -> Callable[ - [metadata_service.GetMetadataStoreRequest], - metadata_store.MetadataStore]: + def get_metadata_store( + self, + ) -> Callable[ + [metadata_service.GetMetadataStoreRequest], metadata_store.MetadataStore + ]: r"""Return a callable for the get metadata store method over gRPC. Retrieves a specific MetadataStore. @@ -290,18 +295,21 @@ def get_metadata_store(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_metadata_store' not in self._stubs: - self._stubs['get_metadata_store'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataStore', + if "get_metadata_store" not in self._stubs: + self._stubs["get_metadata_store"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataStore", request_serializer=metadata_service.GetMetadataStoreRequest.serialize, response_deserializer=metadata_store.MetadataStore.deserialize, ) - return self._stubs['get_metadata_store'] + return self._stubs["get_metadata_store"] @property - def list_metadata_stores(self) -> Callable[ - [metadata_service.ListMetadataStoresRequest], - metadata_service.ListMetadataStoresResponse]: + def list_metadata_stores( + self, + ) -> Callable[ + [metadata_service.ListMetadataStoresRequest], + metadata_service.ListMetadataStoresResponse, + ]: r"""Return a callable for the list metadata stores method over gRPC. Lists MetadataStores for a Location. @@ -316,18 +324,18 @@ def list_metadata_stores(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_metadata_stores' not in self._stubs: - self._stubs['list_metadata_stores'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataStores', + if "list_metadata_stores" not in self._stubs: + self._stubs["list_metadata_stores"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataStores", request_serializer=metadata_service.ListMetadataStoresRequest.serialize, response_deserializer=metadata_service.ListMetadataStoresResponse.deserialize, ) - return self._stubs['list_metadata_stores'] + return self._stubs["list_metadata_stores"] @property - def delete_metadata_store(self) -> Callable[ - [metadata_service.DeleteMetadataStoreRequest], - operations.Operation]: + def delete_metadata_store( + self, + ) -> Callable[[metadata_service.DeleteMetadataStoreRequest], operations.Operation]: r"""Return a callable for the delete metadata store method over gRPC. Deletes a single MetadataStore. @@ -342,18 +350,18 @@ def delete_metadata_store(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_metadata_store' not in self._stubs: - self._stubs['delete_metadata_store'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteMetadataStore', + if "delete_metadata_store" not in self._stubs: + self._stubs["delete_metadata_store"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteMetadataStore", request_serializer=metadata_service.DeleteMetadataStoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_metadata_store'] + return self._stubs["delete_metadata_store"] @property - def create_artifact(self) -> Callable[ - [metadata_service.CreateArtifactRequest], - gca_artifact.Artifact]: + def create_artifact( + self, + ) -> Callable[[metadata_service.CreateArtifactRequest], gca_artifact.Artifact]: r"""Return a callable for the create artifact method over gRPC. Creates an Artifact associated with a MetadataStore. @@ -368,18 +376,18 @@ def create_artifact(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_artifact' not in self._stubs: - self._stubs['create_artifact'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateArtifact', + if "create_artifact" not in self._stubs: + self._stubs["create_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateArtifact", request_serializer=metadata_service.CreateArtifactRequest.serialize, response_deserializer=gca_artifact.Artifact.deserialize, ) - return self._stubs['create_artifact'] + return self._stubs["create_artifact"] @property - def get_artifact(self) -> Callable[ - [metadata_service.GetArtifactRequest], - artifact.Artifact]: + def get_artifact( + self, + ) -> Callable[[metadata_service.GetArtifactRequest], artifact.Artifact]: r"""Return a callable for the get artifact method over gRPC. Retrieves a specific Artifact. @@ -394,18 +402,20 @@ def get_artifact(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_artifact' not in self._stubs: - self._stubs['get_artifact'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetArtifact', + if "get_artifact" not in self._stubs: + self._stubs["get_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetArtifact", request_serializer=metadata_service.GetArtifactRequest.serialize, response_deserializer=artifact.Artifact.deserialize, ) - return self._stubs['get_artifact'] + return self._stubs["get_artifact"] @property - def list_artifacts(self) -> Callable[ - [metadata_service.ListArtifactsRequest], - metadata_service.ListArtifactsResponse]: + def list_artifacts( + self, + ) -> Callable[ + [metadata_service.ListArtifactsRequest], metadata_service.ListArtifactsResponse + ]: r"""Return a callable for the list artifacts method over gRPC. Lists Artifacts in the MetadataStore. @@ -420,18 +430,18 @@ def list_artifacts(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_artifacts' not in self._stubs: - self._stubs['list_artifacts'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListArtifacts', + if "list_artifacts" not in self._stubs: + self._stubs["list_artifacts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListArtifacts", request_serializer=metadata_service.ListArtifactsRequest.serialize, response_deserializer=metadata_service.ListArtifactsResponse.deserialize, ) - return self._stubs['list_artifacts'] + return self._stubs["list_artifacts"] @property - def update_artifact(self) -> Callable[ - [metadata_service.UpdateArtifactRequest], - gca_artifact.Artifact]: + def update_artifact( + self, + ) -> Callable[[metadata_service.UpdateArtifactRequest], gca_artifact.Artifact]: r"""Return a callable for the update artifact method over gRPC. Updates a stored Artifact. @@ -446,18 +456,18 @@ def update_artifact(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_artifact' not in self._stubs: - self._stubs['update_artifact'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateArtifact', + if "update_artifact" not in self._stubs: + self._stubs["update_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/UpdateArtifact", request_serializer=metadata_service.UpdateArtifactRequest.serialize, response_deserializer=gca_artifact.Artifact.deserialize, ) - return self._stubs['update_artifact'] + return self._stubs["update_artifact"] @property - def create_context(self) -> Callable[ - [metadata_service.CreateContextRequest], - gca_context.Context]: + def create_context( + self, + ) -> Callable[[metadata_service.CreateContextRequest], gca_context.Context]: r"""Return a callable for the create context method over gRPC. Creates a Context associated with a MetadataStore. @@ -472,18 +482,18 @@ def create_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_context' not in self._stubs: - self._stubs['create_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateContext', + if "create_context" not in self._stubs: + self._stubs["create_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateContext", request_serializer=metadata_service.CreateContextRequest.serialize, response_deserializer=gca_context.Context.deserialize, ) - return self._stubs['create_context'] + return self._stubs["create_context"] @property - def get_context(self) -> Callable[ - [metadata_service.GetContextRequest], - context.Context]: + def get_context( + self, + ) -> Callable[[metadata_service.GetContextRequest], context.Context]: r"""Return a callable for the get context method over gRPC. Retrieves a specific Context. @@ -498,18 +508,20 @@ def get_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_context' not in self._stubs: - self._stubs['get_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetContext', + if "get_context" not in self._stubs: + self._stubs["get_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetContext", request_serializer=metadata_service.GetContextRequest.serialize, response_deserializer=context.Context.deserialize, ) - return self._stubs['get_context'] + return self._stubs["get_context"] @property - def list_contexts(self) -> Callable[ - [metadata_service.ListContextsRequest], - metadata_service.ListContextsResponse]: + def list_contexts( + self, + ) -> Callable[ + [metadata_service.ListContextsRequest], metadata_service.ListContextsResponse + ]: r"""Return a callable for the list contexts method over gRPC. Lists Contexts on the MetadataStore. @@ -524,18 +536,18 @@ def list_contexts(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_contexts' not in self._stubs: - self._stubs['list_contexts'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListContexts', + if "list_contexts" not in self._stubs: + self._stubs["list_contexts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListContexts", request_serializer=metadata_service.ListContextsRequest.serialize, response_deserializer=metadata_service.ListContextsResponse.deserialize, ) - return self._stubs['list_contexts'] + return self._stubs["list_contexts"] @property - def update_context(self) -> Callable[ - [metadata_service.UpdateContextRequest], - gca_context.Context]: + def update_context( + self, + ) -> Callable[[metadata_service.UpdateContextRequest], gca_context.Context]: r"""Return a callable for the update context method over gRPC. Updates a stored Context. @@ -550,18 +562,18 @@ def update_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_context' not in self._stubs: - self._stubs['update_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateContext', + if "update_context" not in self._stubs: + self._stubs["update_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/UpdateContext", request_serializer=metadata_service.UpdateContextRequest.serialize, response_deserializer=gca_context.Context.deserialize, ) - return self._stubs['update_context'] + return self._stubs["update_context"] @property - def delete_context(self) -> Callable[ - [metadata_service.DeleteContextRequest], - operations.Operation]: + def delete_context( + self, + ) -> Callable[[metadata_service.DeleteContextRequest], operations.Operation]: r"""Return a callable for the delete context method over gRPC. Deletes a stored Context. @@ -576,18 +588,21 @@ def delete_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_context' not in self._stubs: - self._stubs['delete_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteContext', + if "delete_context" not in self._stubs: + self._stubs["delete_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteContext", request_serializer=metadata_service.DeleteContextRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_context'] + return self._stubs["delete_context"] @property - def add_context_artifacts_and_executions(self) -> Callable[ - [metadata_service.AddContextArtifactsAndExecutionsRequest], - metadata_service.AddContextArtifactsAndExecutionsResponse]: + def add_context_artifacts_and_executions( + self, + ) -> Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + metadata_service.AddContextArtifactsAndExecutionsResponse, + ]: r"""Return a callable for the add context artifacts and executions method over gRPC. @@ -605,18 +620,23 @@ def add_context_artifacts_and_executions(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_context_artifacts_and_executions' not in self._stubs: - self._stubs['add_context_artifacts_and_executions'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextArtifactsAndExecutions', + if "add_context_artifacts_and_executions" not in self._stubs: + self._stubs[ + "add_context_artifacts_and_executions" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/AddContextArtifactsAndExecutions", request_serializer=metadata_service.AddContextArtifactsAndExecutionsRequest.serialize, response_deserializer=metadata_service.AddContextArtifactsAndExecutionsResponse.deserialize, ) - return self._stubs['add_context_artifacts_and_executions'] + return self._stubs["add_context_artifacts_and_executions"] @property - def add_context_children(self) -> Callable[ - [metadata_service.AddContextChildrenRequest], - metadata_service.AddContextChildrenResponse]: + def add_context_children( + self, + ) -> Callable[ + [metadata_service.AddContextChildrenRequest], + metadata_service.AddContextChildrenResponse, + ]: r"""Return a callable for the add context children method over gRPC. Adds a set of Contexts as children to a parent Context. If any @@ -635,18 +655,21 @@ def add_context_children(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_context_children' not in self._stubs: - self._stubs['add_context_children'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextChildren', + if "add_context_children" not in self._stubs: + self._stubs["add_context_children"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/AddContextChildren", request_serializer=metadata_service.AddContextChildrenRequest.serialize, response_deserializer=metadata_service.AddContextChildrenResponse.deserialize, ) - return self._stubs['add_context_children'] + return self._stubs["add_context_children"] @property - def query_context_lineage_subgraph(self) -> Callable[ - [metadata_service.QueryContextLineageSubgraphRequest], - lineage_subgraph.LineageSubgraph]: + def query_context_lineage_subgraph( + self, + ) -> Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + lineage_subgraph.LineageSubgraph, + ]: r"""Return a callable for the query context lineage subgraph method over gRPC. Retrieves Artifacts and Executions within the @@ -663,18 +686,20 @@ def query_context_lineage_subgraph(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'query_context_lineage_subgraph' not in self._stubs: - self._stubs['query_context_lineage_subgraph'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/QueryContextLineageSubgraph', + if "query_context_lineage_subgraph" not in self._stubs: + self._stubs[ + "query_context_lineage_subgraph" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/QueryContextLineageSubgraph", request_serializer=metadata_service.QueryContextLineageSubgraphRequest.serialize, response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, ) - return self._stubs['query_context_lineage_subgraph'] + return self._stubs["query_context_lineage_subgraph"] @property - def create_execution(self) -> Callable[ - [metadata_service.CreateExecutionRequest], - gca_execution.Execution]: + def create_execution( + self, + ) -> Callable[[metadata_service.CreateExecutionRequest], gca_execution.Execution]: r"""Return a callable for the create execution method over gRPC. Creates an Execution associated with a MetadataStore. @@ -689,18 +714,18 @@ def create_execution(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_execution' not in self._stubs: - self._stubs['create_execution'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateExecution', + if "create_execution" not in self._stubs: + self._stubs["create_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateExecution", request_serializer=metadata_service.CreateExecutionRequest.serialize, response_deserializer=gca_execution.Execution.deserialize, ) - return self._stubs['create_execution'] + return self._stubs["create_execution"] @property - def get_execution(self) -> Callable[ - [metadata_service.GetExecutionRequest], - execution.Execution]: + def get_execution( + self, + ) -> Callable[[metadata_service.GetExecutionRequest], execution.Execution]: r"""Return a callable for the get execution method over gRPC. Retrieves a specific Execution. @@ -715,18 +740,21 @@ def get_execution(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_execution' not in self._stubs: - self._stubs['get_execution'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetExecution', + if "get_execution" not in self._stubs: + self._stubs["get_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetExecution", request_serializer=metadata_service.GetExecutionRequest.serialize, response_deserializer=execution.Execution.deserialize, ) - return self._stubs['get_execution'] + return self._stubs["get_execution"] @property - def list_executions(self) -> Callable[ - [metadata_service.ListExecutionsRequest], - metadata_service.ListExecutionsResponse]: + def list_executions( + self, + ) -> Callable[ + [metadata_service.ListExecutionsRequest], + metadata_service.ListExecutionsResponse, + ]: r"""Return a callable for the list executions method over gRPC. Lists Executions in the MetadataStore. @@ -741,18 +769,18 @@ def list_executions(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_executions' not in self._stubs: - self._stubs['list_executions'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListExecutions', + if "list_executions" not in self._stubs: + self._stubs["list_executions"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListExecutions", request_serializer=metadata_service.ListExecutionsRequest.serialize, response_deserializer=metadata_service.ListExecutionsResponse.deserialize, ) - return self._stubs['list_executions'] + return self._stubs["list_executions"] @property - def update_execution(self) -> Callable[ - [metadata_service.UpdateExecutionRequest], - gca_execution.Execution]: + def update_execution( + self, + ) -> Callable[[metadata_service.UpdateExecutionRequest], gca_execution.Execution]: r"""Return a callable for the update execution method over gRPC. Updates a stored Execution. @@ -767,18 +795,21 @@ def update_execution(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_execution' not in self._stubs: - self._stubs['update_execution'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateExecution', + if "update_execution" not in self._stubs: + self._stubs["update_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/UpdateExecution", request_serializer=metadata_service.UpdateExecutionRequest.serialize, response_deserializer=gca_execution.Execution.deserialize, ) - return self._stubs['update_execution'] + return self._stubs["update_execution"] @property - def add_execution_events(self) -> Callable[ - [metadata_service.AddExecutionEventsRequest], - metadata_service.AddExecutionEventsResponse]: + def add_execution_events( + self, + ) -> Callable[ + [metadata_service.AddExecutionEventsRequest], + metadata_service.AddExecutionEventsResponse, + ]: r"""Return a callable for the add execution events method over gRPC. Adds Events for denoting whether each Artifact was an @@ -796,18 +827,21 @@ def add_execution_events(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_execution_events' not in self._stubs: - self._stubs['add_execution_events'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/AddExecutionEvents', + if "add_execution_events" not in self._stubs: + self._stubs["add_execution_events"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/AddExecutionEvents", request_serializer=metadata_service.AddExecutionEventsRequest.serialize, response_deserializer=metadata_service.AddExecutionEventsResponse.deserialize, ) - return self._stubs['add_execution_events'] + return self._stubs["add_execution_events"] @property - def query_execution_inputs_and_outputs(self) -> Callable[ - [metadata_service.QueryExecutionInputsAndOutputsRequest], - lineage_subgraph.LineageSubgraph]: + def query_execution_inputs_and_outputs( + self, + ) -> Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + lineage_subgraph.LineageSubgraph, + ]: r"""Return a callable for the query execution inputs and outputs method over gRPC. @@ -825,18 +859,23 @@ def query_execution_inputs_and_outputs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'query_execution_inputs_and_outputs' not in self._stubs: - self._stubs['query_execution_inputs_and_outputs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/QueryExecutionInputsAndOutputs', + if "query_execution_inputs_and_outputs" not in self._stubs: + self._stubs[ + "query_execution_inputs_and_outputs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/QueryExecutionInputsAndOutputs", request_serializer=metadata_service.QueryExecutionInputsAndOutputsRequest.serialize, response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, ) - return self._stubs['query_execution_inputs_and_outputs'] + return self._stubs["query_execution_inputs_and_outputs"] @property - def create_metadata_schema(self) -> Callable[ - [metadata_service.CreateMetadataSchemaRequest], - gca_metadata_schema.MetadataSchema]: + def create_metadata_schema( + self, + ) -> Callable[ + [metadata_service.CreateMetadataSchemaRequest], + gca_metadata_schema.MetadataSchema, + ]: r"""Return a callable for the create metadata schema method over gRPC. Creates an MetadataSchema. @@ -851,18 +890,20 @@ def create_metadata_schema(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_metadata_schema' not in self._stubs: - self._stubs['create_metadata_schema'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataSchema', + if "create_metadata_schema" not in self._stubs: + self._stubs["create_metadata_schema"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataSchema", request_serializer=metadata_service.CreateMetadataSchemaRequest.serialize, response_deserializer=gca_metadata_schema.MetadataSchema.deserialize, ) - return self._stubs['create_metadata_schema'] + return self._stubs["create_metadata_schema"] @property - def get_metadata_schema(self) -> Callable[ - [metadata_service.GetMetadataSchemaRequest], - metadata_schema.MetadataSchema]: + def get_metadata_schema( + self, + ) -> Callable[ + [metadata_service.GetMetadataSchemaRequest], metadata_schema.MetadataSchema + ]: r"""Return a callable for the get metadata schema method over gRPC. Retrieves a specific MetadataSchema. @@ -877,18 +918,21 @@ def get_metadata_schema(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_metadata_schema' not in self._stubs: - self._stubs['get_metadata_schema'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataSchema', + if "get_metadata_schema" not in self._stubs: + self._stubs["get_metadata_schema"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataSchema", request_serializer=metadata_service.GetMetadataSchemaRequest.serialize, response_deserializer=metadata_schema.MetadataSchema.deserialize, ) - return self._stubs['get_metadata_schema'] + return self._stubs["get_metadata_schema"] @property - def list_metadata_schemas(self) -> Callable[ - [metadata_service.ListMetadataSchemasRequest], - metadata_service.ListMetadataSchemasResponse]: + def list_metadata_schemas( + self, + ) -> Callable[ + [metadata_service.ListMetadataSchemasRequest], + metadata_service.ListMetadataSchemasResponse, + ]: r"""Return a callable for the list metadata schemas method over gRPC. Lists MetadataSchemas. @@ -903,15 +947,13 @@ def list_metadata_schemas(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_metadata_schemas' not in self._stubs: - self._stubs['list_metadata_schemas'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataSchemas', + if "list_metadata_schemas" not in self._stubs: + self._stubs["list_metadata_schemas"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataSchemas", request_serializer=metadata_service.ListMetadataSchemasRequest.serialize, response_deserializer=metadata_service.ListMetadataSchemasResponse.deserialize, ) - return self._stubs['list_metadata_schemas'] + return self._stubs["list_metadata_schemas"] -__all__ = ( - 'MetadataServiceGrpcTransport', -) +__all__ = ("MetadataServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py index bedea761c0..ce55514b2f 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import artifact @@ -62,13 +62,15 @@ class MetadataServiceGrpcAsyncIOTransport(MetadataServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -97,22 +99,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -251,9 +255,11 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_metadata_store(self) -> Callable[ - [metadata_service.CreateMetadataStoreRequest], - Awaitable[operations.Operation]]: + def create_metadata_store( + self, + ) -> Callable[ + [metadata_service.CreateMetadataStoreRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create metadata store method over gRPC. Initializes a MetadataStore, including allocation of @@ -269,18 +275,21 @@ def create_metadata_store(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_metadata_store' not in self._stubs: - self._stubs['create_metadata_store'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataStore', + if "create_metadata_store" not in self._stubs: + self._stubs["create_metadata_store"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataStore", request_serializer=metadata_service.CreateMetadataStoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_metadata_store'] + return self._stubs["create_metadata_store"] @property - def get_metadata_store(self) -> Callable[ - [metadata_service.GetMetadataStoreRequest], - Awaitable[metadata_store.MetadataStore]]: + def get_metadata_store( + self, + ) -> Callable[ + [metadata_service.GetMetadataStoreRequest], + Awaitable[metadata_store.MetadataStore], + ]: r"""Return a callable for the get metadata store method over gRPC. Retrieves a specific MetadataStore. @@ -295,18 +304,21 @@ def get_metadata_store(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_metadata_store' not in self._stubs: - self._stubs['get_metadata_store'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataStore', + if "get_metadata_store" not in self._stubs: + self._stubs["get_metadata_store"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataStore", request_serializer=metadata_service.GetMetadataStoreRequest.serialize, response_deserializer=metadata_store.MetadataStore.deserialize, ) - return self._stubs['get_metadata_store'] + return self._stubs["get_metadata_store"] @property - def list_metadata_stores(self) -> Callable[ - [metadata_service.ListMetadataStoresRequest], - Awaitable[metadata_service.ListMetadataStoresResponse]]: + def list_metadata_stores( + self, + ) -> Callable[ + [metadata_service.ListMetadataStoresRequest], + Awaitable[metadata_service.ListMetadataStoresResponse], + ]: r"""Return a callable for the list metadata stores method over gRPC. Lists MetadataStores for a Location. @@ -321,18 +333,20 @@ def list_metadata_stores(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_metadata_stores' not in self._stubs: - self._stubs['list_metadata_stores'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataStores', + if "list_metadata_stores" not in self._stubs: + self._stubs["list_metadata_stores"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataStores", request_serializer=metadata_service.ListMetadataStoresRequest.serialize, response_deserializer=metadata_service.ListMetadataStoresResponse.deserialize, ) - return self._stubs['list_metadata_stores'] + return self._stubs["list_metadata_stores"] @property - def delete_metadata_store(self) -> Callable[ - [metadata_service.DeleteMetadataStoreRequest], - Awaitable[operations.Operation]]: + def delete_metadata_store( + self, + ) -> Callable[ + [metadata_service.DeleteMetadataStoreRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete metadata store method over gRPC. Deletes a single MetadataStore. @@ -347,18 +361,20 @@ def delete_metadata_store(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_metadata_store' not in self._stubs: - self._stubs['delete_metadata_store'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteMetadataStore', + if "delete_metadata_store" not in self._stubs: + self._stubs["delete_metadata_store"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteMetadataStore", request_serializer=metadata_service.DeleteMetadataStoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_metadata_store'] + return self._stubs["delete_metadata_store"] @property - def create_artifact(self) -> Callable[ - [metadata_service.CreateArtifactRequest], - Awaitable[gca_artifact.Artifact]]: + def create_artifact( + self, + ) -> Callable[ + [metadata_service.CreateArtifactRequest], Awaitable[gca_artifact.Artifact] + ]: r"""Return a callable for the create artifact method over gRPC. Creates an Artifact associated with a MetadataStore. @@ -373,18 +389,18 @@ def create_artifact(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_artifact' not in self._stubs: - self._stubs['create_artifact'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateArtifact', + if "create_artifact" not in self._stubs: + self._stubs["create_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateArtifact", request_serializer=metadata_service.CreateArtifactRequest.serialize, response_deserializer=gca_artifact.Artifact.deserialize, ) - return self._stubs['create_artifact'] + return self._stubs["create_artifact"] @property - def get_artifact(self) -> Callable[ - [metadata_service.GetArtifactRequest], - Awaitable[artifact.Artifact]]: + def get_artifact( + self, + ) -> Callable[[metadata_service.GetArtifactRequest], Awaitable[artifact.Artifact]]: r"""Return a callable for the get artifact method over gRPC. Retrieves a specific Artifact. @@ -399,18 +415,21 @@ def get_artifact(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_artifact' not in self._stubs: - self._stubs['get_artifact'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetArtifact', + if "get_artifact" not in self._stubs: + self._stubs["get_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetArtifact", request_serializer=metadata_service.GetArtifactRequest.serialize, response_deserializer=artifact.Artifact.deserialize, ) - return self._stubs['get_artifact'] + return self._stubs["get_artifact"] @property - def list_artifacts(self) -> Callable[ - [metadata_service.ListArtifactsRequest], - Awaitable[metadata_service.ListArtifactsResponse]]: + def list_artifacts( + self, + ) -> Callable[ + [metadata_service.ListArtifactsRequest], + Awaitable[metadata_service.ListArtifactsResponse], + ]: r"""Return a callable for the list artifacts method over gRPC. Lists Artifacts in the MetadataStore. @@ -425,18 +444,20 @@ def list_artifacts(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_artifacts' not in self._stubs: - self._stubs['list_artifacts'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListArtifacts', + if "list_artifacts" not in self._stubs: + self._stubs["list_artifacts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListArtifacts", request_serializer=metadata_service.ListArtifactsRequest.serialize, response_deserializer=metadata_service.ListArtifactsResponse.deserialize, ) - return self._stubs['list_artifacts'] + return self._stubs["list_artifacts"] @property - def update_artifact(self) -> Callable[ - [metadata_service.UpdateArtifactRequest], - Awaitable[gca_artifact.Artifact]]: + def update_artifact( + self, + ) -> Callable[ + [metadata_service.UpdateArtifactRequest], Awaitable[gca_artifact.Artifact] + ]: r"""Return a callable for the update artifact method over gRPC. Updates a stored Artifact. @@ -451,18 +472,20 @@ def update_artifact(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_artifact' not in self._stubs: - self._stubs['update_artifact'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateArtifact', + if "update_artifact" not in self._stubs: + self._stubs["update_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/UpdateArtifact", request_serializer=metadata_service.UpdateArtifactRequest.serialize, response_deserializer=gca_artifact.Artifact.deserialize, ) - return self._stubs['update_artifact'] + return self._stubs["update_artifact"] @property - def create_context(self) -> Callable[ - [metadata_service.CreateContextRequest], - Awaitable[gca_context.Context]]: + def create_context( + self, + ) -> Callable[ + [metadata_service.CreateContextRequest], Awaitable[gca_context.Context] + ]: r"""Return a callable for the create context method over gRPC. Creates a Context associated with a MetadataStore. @@ -477,18 +500,18 @@ def create_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_context' not in self._stubs: - self._stubs['create_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateContext', + if "create_context" not in self._stubs: + self._stubs["create_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateContext", request_serializer=metadata_service.CreateContextRequest.serialize, response_deserializer=gca_context.Context.deserialize, ) - return self._stubs['create_context'] + return self._stubs["create_context"] @property - def get_context(self) -> Callable[ - [metadata_service.GetContextRequest], - Awaitable[context.Context]]: + def get_context( + self, + ) -> Callable[[metadata_service.GetContextRequest], Awaitable[context.Context]]: r"""Return a callable for the get context method over gRPC. Retrieves a specific Context. @@ -503,18 +526,21 @@ def get_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_context' not in self._stubs: - self._stubs['get_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetContext', + if "get_context" not in self._stubs: + self._stubs["get_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetContext", request_serializer=metadata_service.GetContextRequest.serialize, response_deserializer=context.Context.deserialize, ) - return self._stubs['get_context'] + return self._stubs["get_context"] @property - def list_contexts(self) -> Callable[ - [metadata_service.ListContextsRequest], - Awaitable[metadata_service.ListContextsResponse]]: + def list_contexts( + self, + ) -> Callable[ + [metadata_service.ListContextsRequest], + Awaitable[metadata_service.ListContextsResponse], + ]: r"""Return a callable for the list contexts method over gRPC. Lists Contexts on the MetadataStore. @@ -529,18 +555,20 @@ def list_contexts(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_contexts' not in self._stubs: - self._stubs['list_contexts'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListContexts', + if "list_contexts" not in self._stubs: + self._stubs["list_contexts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListContexts", request_serializer=metadata_service.ListContextsRequest.serialize, response_deserializer=metadata_service.ListContextsResponse.deserialize, ) - return self._stubs['list_contexts'] + return self._stubs["list_contexts"] @property - def update_context(self) -> Callable[ - [metadata_service.UpdateContextRequest], - Awaitable[gca_context.Context]]: + def update_context( + self, + ) -> Callable[ + [metadata_service.UpdateContextRequest], Awaitable[gca_context.Context] + ]: r"""Return a callable for the update context method over gRPC. Updates a stored Context. @@ -555,18 +583,20 @@ def update_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_context' not in self._stubs: - self._stubs['update_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateContext', + if "update_context" not in self._stubs: + self._stubs["update_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/UpdateContext", request_serializer=metadata_service.UpdateContextRequest.serialize, response_deserializer=gca_context.Context.deserialize, ) - return self._stubs['update_context'] + return self._stubs["update_context"] @property - def delete_context(self) -> Callable[ - [metadata_service.DeleteContextRequest], - Awaitable[operations.Operation]]: + def delete_context( + self, + ) -> Callable[ + [metadata_service.DeleteContextRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete context method over gRPC. Deletes a stored Context. @@ -581,18 +611,21 @@ def delete_context(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_context' not in self._stubs: - self._stubs['delete_context'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/DeleteContext', + if "delete_context" not in self._stubs: + self._stubs["delete_context"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteContext", request_serializer=metadata_service.DeleteContextRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_context'] + return self._stubs["delete_context"] @property - def add_context_artifacts_and_executions(self) -> Callable[ - [metadata_service.AddContextArtifactsAndExecutionsRequest], - Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse]]: + def add_context_artifacts_and_executions( + self, + ) -> Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse], + ]: r"""Return a callable for the add context artifacts and executions method over gRPC. @@ -610,18 +643,23 @@ def add_context_artifacts_and_executions(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_context_artifacts_and_executions' not in self._stubs: - self._stubs['add_context_artifacts_and_executions'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextArtifactsAndExecutions', + if "add_context_artifacts_and_executions" not in self._stubs: + self._stubs[ + "add_context_artifacts_and_executions" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/AddContextArtifactsAndExecutions", request_serializer=metadata_service.AddContextArtifactsAndExecutionsRequest.serialize, response_deserializer=metadata_service.AddContextArtifactsAndExecutionsResponse.deserialize, ) - return self._stubs['add_context_artifacts_and_executions'] + return self._stubs["add_context_artifacts_and_executions"] @property - def add_context_children(self) -> Callable[ - [metadata_service.AddContextChildrenRequest], - Awaitable[metadata_service.AddContextChildrenResponse]]: + def add_context_children( + self, + ) -> Callable[ + [metadata_service.AddContextChildrenRequest], + Awaitable[metadata_service.AddContextChildrenResponse], + ]: r"""Return a callable for the add context children method over gRPC. Adds a set of Contexts as children to a parent Context. If any @@ -640,18 +678,21 @@ def add_context_children(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_context_children' not in self._stubs: - self._stubs['add_context_children'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/AddContextChildren', + if "add_context_children" not in self._stubs: + self._stubs["add_context_children"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/AddContextChildren", request_serializer=metadata_service.AddContextChildrenRequest.serialize, response_deserializer=metadata_service.AddContextChildrenResponse.deserialize, ) - return self._stubs['add_context_children'] + return self._stubs["add_context_children"] @property - def query_context_lineage_subgraph(self) -> Callable[ - [metadata_service.QueryContextLineageSubgraphRequest], - Awaitable[lineage_subgraph.LineageSubgraph]]: + def query_context_lineage_subgraph( + self, + ) -> Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + Awaitable[lineage_subgraph.LineageSubgraph], + ]: r"""Return a callable for the query context lineage subgraph method over gRPC. Retrieves Artifacts and Executions within the @@ -668,18 +709,22 @@ def query_context_lineage_subgraph(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'query_context_lineage_subgraph' not in self._stubs: - self._stubs['query_context_lineage_subgraph'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/QueryContextLineageSubgraph', + if "query_context_lineage_subgraph" not in self._stubs: + self._stubs[ + "query_context_lineage_subgraph" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/QueryContextLineageSubgraph", request_serializer=metadata_service.QueryContextLineageSubgraphRequest.serialize, response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, ) - return self._stubs['query_context_lineage_subgraph'] + return self._stubs["query_context_lineage_subgraph"] @property - def create_execution(self) -> Callable[ - [metadata_service.CreateExecutionRequest], - Awaitable[gca_execution.Execution]]: + def create_execution( + self, + ) -> Callable[ + [metadata_service.CreateExecutionRequest], Awaitable[gca_execution.Execution] + ]: r"""Return a callable for the create execution method over gRPC. Creates an Execution associated with a MetadataStore. @@ -694,18 +739,20 @@ def create_execution(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_execution' not in self._stubs: - self._stubs['create_execution'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateExecution', + if "create_execution" not in self._stubs: + self._stubs["create_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateExecution", request_serializer=metadata_service.CreateExecutionRequest.serialize, response_deserializer=gca_execution.Execution.deserialize, ) - return self._stubs['create_execution'] + return self._stubs["create_execution"] @property - def get_execution(self) -> Callable[ - [metadata_service.GetExecutionRequest], - Awaitable[execution.Execution]]: + def get_execution( + self, + ) -> Callable[ + [metadata_service.GetExecutionRequest], Awaitable[execution.Execution] + ]: r"""Return a callable for the get execution method over gRPC. Retrieves a specific Execution. @@ -720,18 +767,21 @@ def get_execution(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_execution' not in self._stubs: - self._stubs['get_execution'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetExecution', + if "get_execution" not in self._stubs: + self._stubs["get_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetExecution", request_serializer=metadata_service.GetExecutionRequest.serialize, response_deserializer=execution.Execution.deserialize, ) - return self._stubs['get_execution'] + return self._stubs["get_execution"] @property - def list_executions(self) -> Callable[ - [metadata_service.ListExecutionsRequest], - Awaitable[metadata_service.ListExecutionsResponse]]: + def list_executions( + self, + ) -> Callable[ + [metadata_service.ListExecutionsRequest], + Awaitable[metadata_service.ListExecutionsResponse], + ]: r"""Return a callable for the list executions method over gRPC. Lists Executions in the MetadataStore. @@ -746,18 +796,20 @@ def list_executions(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_executions' not in self._stubs: - self._stubs['list_executions'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListExecutions', + if "list_executions" not in self._stubs: + self._stubs["list_executions"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListExecutions", request_serializer=metadata_service.ListExecutionsRequest.serialize, response_deserializer=metadata_service.ListExecutionsResponse.deserialize, ) - return self._stubs['list_executions'] + return self._stubs["list_executions"] @property - def update_execution(self) -> Callable[ - [metadata_service.UpdateExecutionRequest], - Awaitable[gca_execution.Execution]]: + def update_execution( + self, + ) -> Callable[ + [metadata_service.UpdateExecutionRequest], Awaitable[gca_execution.Execution] + ]: r"""Return a callable for the update execution method over gRPC. Updates a stored Execution. @@ -772,18 +824,21 @@ def update_execution(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_execution' not in self._stubs: - self._stubs['update_execution'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/UpdateExecution', + if "update_execution" not in self._stubs: + self._stubs["update_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/UpdateExecution", request_serializer=metadata_service.UpdateExecutionRequest.serialize, response_deserializer=gca_execution.Execution.deserialize, ) - return self._stubs['update_execution'] + return self._stubs["update_execution"] @property - def add_execution_events(self) -> Callable[ - [metadata_service.AddExecutionEventsRequest], - Awaitable[metadata_service.AddExecutionEventsResponse]]: + def add_execution_events( + self, + ) -> Callable[ + [metadata_service.AddExecutionEventsRequest], + Awaitable[metadata_service.AddExecutionEventsResponse], + ]: r"""Return a callable for the add execution events method over gRPC. Adds Events for denoting whether each Artifact was an @@ -801,18 +856,21 @@ def add_execution_events(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_execution_events' not in self._stubs: - self._stubs['add_execution_events'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/AddExecutionEvents', + if "add_execution_events" not in self._stubs: + self._stubs["add_execution_events"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/AddExecutionEvents", request_serializer=metadata_service.AddExecutionEventsRequest.serialize, response_deserializer=metadata_service.AddExecutionEventsResponse.deserialize, ) - return self._stubs['add_execution_events'] + return self._stubs["add_execution_events"] @property - def query_execution_inputs_and_outputs(self) -> Callable[ - [metadata_service.QueryExecutionInputsAndOutputsRequest], - Awaitable[lineage_subgraph.LineageSubgraph]]: + def query_execution_inputs_and_outputs( + self, + ) -> Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + Awaitable[lineage_subgraph.LineageSubgraph], + ]: r"""Return a callable for the query execution inputs and outputs method over gRPC. @@ -830,18 +888,23 @@ def query_execution_inputs_and_outputs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'query_execution_inputs_and_outputs' not in self._stubs: - self._stubs['query_execution_inputs_and_outputs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/QueryExecutionInputsAndOutputs', + if "query_execution_inputs_and_outputs" not in self._stubs: + self._stubs[ + "query_execution_inputs_and_outputs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/QueryExecutionInputsAndOutputs", request_serializer=metadata_service.QueryExecutionInputsAndOutputsRequest.serialize, response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, ) - return self._stubs['query_execution_inputs_and_outputs'] + return self._stubs["query_execution_inputs_and_outputs"] @property - def create_metadata_schema(self) -> Callable[ - [metadata_service.CreateMetadataSchemaRequest], - Awaitable[gca_metadata_schema.MetadataSchema]]: + def create_metadata_schema( + self, + ) -> Callable[ + [metadata_service.CreateMetadataSchemaRequest], + Awaitable[gca_metadata_schema.MetadataSchema], + ]: r"""Return a callable for the create metadata schema method over gRPC. Creates an MetadataSchema. @@ -856,18 +919,21 @@ def create_metadata_schema(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_metadata_schema' not in self._stubs: - self._stubs['create_metadata_schema'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataSchema', + if "create_metadata_schema" not in self._stubs: + self._stubs["create_metadata_schema"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/CreateMetadataSchema", request_serializer=metadata_service.CreateMetadataSchemaRequest.serialize, response_deserializer=gca_metadata_schema.MetadataSchema.deserialize, ) - return self._stubs['create_metadata_schema'] + return self._stubs["create_metadata_schema"] @property - def get_metadata_schema(self) -> Callable[ - [metadata_service.GetMetadataSchemaRequest], - Awaitable[metadata_schema.MetadataSchema]]: + def get_metadata_schema( + self, + ) -> Callable[ + [metadata_service.GetMetadataSchemaRequest], + Awaitable[metadata_schema.MetadataSchema], + ]: r"""Return a callable for the get metadata schema method over gRPC. Retrieves a specific MetadataSchema. @@ -882,18 +948,21 @@ def get_metadata_schema(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_metadata_schema' not in self._stubs: - self._stubs['get_metadata_schema'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataSchema', + if "get_metadata_schema" not in self._stubs: + self._stubs["get_metadata_schema"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/GetMetadataSchema", request_serializer=metadata_service.GetMetadataSchemaRequest.serialize, response_deserializer=metadata_schema.MetadataSchema.deserialize, ) - return self._stubs['get_metadata_schema'] + return self._stubs["get_metadata_schema"] @property - def list_metadata_schemas(self) -> Callable[ - [metadata_service.ListMetadataSchemasRequest], - Awaitable[metadata_service.ListMetadataSchemasResponse]]: + def list_metadata_schemas( + self, + ) -> Callable[ + [metadata_service.ListMetadataSchemasRequest], + Awaitable[metadata_service.ListMetadataSchemasResponse], + ]: r"""Return a callable for the list metadata schemas method over gRPC. Lists MetadataSchemas. @@ -908,15 +977,13 @@ def list_metadata_schemas(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_metadata_schemas' not in self._stubs: - self._stubs['list_metadata_schemas'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataSchemas', + if "list_metadata_schemas" not in self._stubs: + self._stubs["list_metadata_schemas"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/ListMetadataSchemas", request_serializer=metadata_service.ListMetadataSchemasRequest.serialize, response_deserializer=metadata_service.ListMetadataSchemasResponse.deserialize, ) - return self._stubs['list_metadata_schemas'] + return self._stubs["list_metadata_schemas"] -__all__ = ( - 'MetadataServiceGrpcAsyncIOTransport', -) +__all__ = ("MetadataServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py index c533a12b45..1d6216d1f7 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import MigrationServiceAsyncClient __all__ = ( - 'MigrationServiceClient', - 'MigrationServiceAsyncClient', + "MigrationServiceClient", + "MigrationServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index d79e43c9c1..c4db3f14d7 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -51,7 +51,9 @@ class MigrationServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path) - parse_annotated_dataset_path = staticmethod(MigrationServiceClient.parse_annotated_dataset_path) + parse_annotated_dataset_path = staticmethod( + MigrationServiceClient.parse_annotated_dataset_path + ) dataset_path = staticmethod(MigrationServiceClient.dataset_path) parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) @@ -65,20 +67,34 @@ class MigrationServiceAsyncClient: version_path = staticmethod(MigrationServiceClient.version_path) parse_version_path = staticmethod(MigrationServiceClient.parse_version_path) - common_billing_account_path = staticmethod(MigrationServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(MigrationServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + MigrationServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + MigrationServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(MigrationServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(MigrationServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + MigrationServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(MigrationServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(MigrationServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + MigrationServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + MigrationServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(MigrationServiceClient.common_project_path) - parse_common_project_path = staticmethod(MigrationServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + MigrationServiceClient.parse_common_project_path + ) common_location_path = staticmethod(MigrationServiceClient.common_location_path) - parse_common_location_path = staticmethod(MigrationServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + MigrationServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -121,14 +137,18 @@ def transport(self) -> MigrationServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)) + get_transport_class = functools.partial( + type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, MigrationServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, MigrationServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -167,17 +187,17 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesAsyncPager: + async def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -218,8 +238,10 @@ async def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = migration_service.SearchMigratableResourcesRequest(request) @@ -240,40 +262,33 @@ async def search_migratable_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchMigratableResourcesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -322,8 +337,10 @@ async def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = migration_service.BatchMigrateResourcesRequest(request) @@ -347,18 +364,11 @@ async def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -372,21 +382,14 @@ async def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceAsyncClient', -) +__all__ = ("MigrationServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index a636962692..a1f5e7f79f 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,13 +50,14 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry['grpc'] = MigrationServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MigrationServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry["grpc"] = MigrationServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -110,7 +111,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -145,9 +146,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MigrationServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -162,143 +162,183 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: + def annotated_dataset_path( + project: str, dataset: str, annotated_dataset: str, + ) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str,str]: + def parse_annotated_dataset_path(path: str) -> Dict[str, str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def version_path(project: str,model: str,version: str,) -> str: + def version_path(project: str, model: str, version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + return "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) @staticmethod - def parse_version_path(path: str) -> Dict[str,str]: + def parse_version_path(path: str) -> Dict[str, str]: """Parse a version path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -342,7 +382,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -352,7 +394,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -364,7 +408,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -376,8 +422,10 @@ def __init__(self, *, if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -396,14 +444,15 @@ def __init__(self, *, client_info=client_info, ) - def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -444,8 +493,10 @@ def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -462,45 +513,40 @@ def search_migratable_resources(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -549,8 +595,10 @@ def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -574,18 +622,11 @@ def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation.from_gapic( @@ -599,21 +640,14 @@ def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceClient', -) +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py index d25339203b..f0a1dfa43f 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import migratable_resource from google.cloud.aiplatform_v1beta1.types import migration_service @@ -38,12 +47,15 @@ class SearchMigratableResourcesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., migration_service.SearchMigratableResourcesResponse], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: yield from page.migratable_resources def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchMigratableResourcesAsyncPager: @@ -97,12 +109,17 @@ class SearchMigratableResourcesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[migration_service.SearchMigratableResourcesResponse]], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[migration_service.SearchMigratableResourcesResponse] + ], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + async def pages( + self, + ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py index 9fb765fdcc..38c72756f6 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] -_transport_registry['grpc'] = MigrationServiceGrpcTransport -_transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = MigrationServiceGrpcTransport +_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport __all__ = ( - 'MigrationServiceTransport', - 'MigrationServiceGrpcTransport', - 'MigrationServiceGrpcAsyncIOTransport', + "MigrationServiceTransport", + "MigrationServiceGrpcTransport", + "MigrationServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py index ba00adae0e..f3324f22c6 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -33,29 +33,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class MigrationServiceTransport(abc.ABC): """Abstract transport class for MigrationService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -78,8 +78,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -88,17 +88,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -116,7 +118,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -125,24 +126,25 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def search_migratable_resources(self) -> typing.Callable[ - [migration_service.SearchMigratableResourcesRequest], - typing.Union[ - migration_service.SearchMigratableResourcesResponse, - typing.Awaitable[migration_service.SearchMigratableResourcesResponse] - ]]: + def search_migratable_resources( + self, + ) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse], + ], + ]: raise NotImplementedError() @property - def batch_migrate_resources(self) -> typing.Callable[ - [migration_service.BatchMigrateResourcesRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def batch_migrate_resources( + self, + ) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'MigrationServiceTransport', -) +__all__ = ("MigrationServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 28a61272bf..7c63224a7a 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -47,21 +47,24 @@ class MigrationServiceGrpcTransport(MigrationServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -173,13 +176,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -212,7 +217,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -230,17 +235,18 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def search_migratable_resources(self) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - migration_service.SearchMigratableResourcesResponse]: + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse, + ]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -258,18 +264,20 @@ def search_migratable_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_migratable_resources' not in self._stubs: - self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs['search_migratable_resources'] + return self._stubs["search_migratable_resources"] @property - def batch_migrate_resources(self) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - operations.Operation]: + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], operations.Operation + ]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -286,15 +294,13 @@ def batch_migrate_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_migrate_resources' not in self._stubs: - self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_migrate_resources'] + return self._stubs["batch_migrate_resources"] -__all__ = ( - 'MigrationServiceGrpcTransport', -) +__all__ = ("MigrationServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 4648d86616..100739ea7e 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import migration_service @@ -54,13 +54,15 @@ class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -89,22 +91,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -243,9 +247,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def search_migratable_resources(self) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - Awaitable[migration_service.SearchMigratableResourcesResponse]]: + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse], + ]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -263,18 +270,21 @@ def search_migratable_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_migratable_resources' not in self._stubs: - self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs['search_migratable_resources'] + return self._stubs["search_migratable_resources"] @property - def batch_migrate_resources(self) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - Awaitable[operations.Operation]]: + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -291,15 +301,13 @@ def batch_migrate_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_migrate_resources' not in self._stubs: - self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_migrate_resources'] + return self._stubs["batch_migrate_resources"] -__all__ = ( - 'MigrationServiceGrpcAsyncIOTransport', -) +__all__ = ("MigrationServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py index 3ee8fc6e9e..b39295ebfe 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import ModelServiceAsyncClient __all__ = ( - 'ModelServiceClient', - 'ModelServiceAsyncClient', + "ModelServiceClient", + "ModelServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index 72cfd1e4e4..a901ead2b1 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -63,26 +63,44 @@ class ModelServiceAsyncClient: model_path = staticmethod(ModelServiceClient.model_path) parse_model_path = staticmethod(ModelServiceClient.parse_model_path) model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) - parse_model_evaluation_path = staticmethod(ModelServiceClient.parse_model_evaluation_path) - model_evaluation_slice_path = staticmethod(ModelServiceClient.model_evaluation_slice_path) - parse_model_evaluation_slice_path = staticmethod(ModelServiceClient.parse_model_evaluation_slice_path) + parse_model_evaluation_path = staticmethod( + ModelServiceClient.parse_model_evaluation_path + ) + model_evaluation_slice_path = staticmethod( + ModelServiceClient.model_evaluation_slice_path + ) + parse_model_evaluation_slice_path = staticmethod( + ModelServiceClient.parse_model_evaluation_slice_path + ) training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod(ModelServiceClient.parse_training_pipeline_path) + parse_training_pipeline_path = staticmethod( + ModelServiceClient.parse_training_pipeline_path + ) - common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(ModelServiceClient.common_folder_path) parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + ModelServiceClient.parse_common_project_path + ) common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + ModelServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -125,14 +143,18 @@ def transport(self) -> ModelServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) + get_transport_class = functools.partial( + type(ModelServiceClient).get_transport_class, type(ModelServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -171,18 +193,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Uploads a Model artifact into AI Platform. Args: @@ -225,8 +247,10 @@ async def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.UploadModelRequest(request) @@ -249,18 +273,11 @@ async def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -273,14 +290,15 @@ async def upload_model(self, # Done; return the response. return response - async def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + async def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -310,8 +328,10 @@ async def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelRequest(request) @@ -332,30 +352,24 @@ async def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: + async def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: r"""Lists Models in a Location. Args: @@ -391,8 +405,10 @@ async def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelsRequest(request) @@ -413,40 +429,31 @@ async def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + async def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -484,8 +491,10 @@ async def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.UpdateModelRequest(request) @@ -508,30 +517,26 @@ async def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -579,8 +584,10 @@ async def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.DeleteModelRequest(request) @@ -601,18 +608,11 @@ async def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -625,15 +625,16 @@ async def delete_model(self, # Done; return the response. return response - async def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -681,8 +682,10 @@ async def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ExportModelRequest(request) @@ -705,18 +708,11 @@ async def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -729,14 +725,15 @@ async def export_model(self, # Done; return the response. return response - async def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + async def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -771,8 +768,10 @@ async def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelEvaluationRequest(request) @@ -793,30 +792,24 @@ async def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsAsyncPager: + async def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: r"""Lists ModelEvaluations in a Model. Args: @@ -852,8 +845,10 @@ async def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelEvaluationsRequest(request) @@ -874,39 +869,30 @@ async def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + async def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -941,8 +927,10 @@ async def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelEvaluationSliceRequest(request) @@ -963,30 +951,24 @@ async def get_model_evaluation_slice(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesAsyncPager: + async def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1022,8 +1004,10 @@ async def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelEvaluationSlicesRequest(request) @@ -1044,47 +1028,30 @@ async def list_model_evaluation_slices(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationSlicesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceAsyncClient', -) +__all__ = ("ModelServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 29e081bc10..8b14e16e0b 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,13 +61,12 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry['grpc'] = ModelServiceGrpcTransport - _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[ModelServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -118,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -153,9 +152,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -170,121 +168,162 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_path(path: str) -> Dict[str, str]: """Parse a model_evaluation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -328,7 +367,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -338,7 +379,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -350,7 +393,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -362,8 +407,10 @@ def __init__(self, *, if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -382,15 +429,16 @@ def __init__(self, *, client_info=client_info, ) - def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -433,8 +481,10 @@ def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -458,18 +508,11 @@ def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -482,14 +525,15 @@ def upload_model(self, # Done; return the response. return response - def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -519,8 +563,10 @@ def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -542,30 +588,24 @@ def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -601,8 +641,10 @@ def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -624,40 +666,31 @@ def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -695,8 +728,10 @@ def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -720,30 +755,26 @@ def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -791,8 +822,10 @@ def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -814,18 +847,11 @@ def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -838,15 +864,16 @@ def delete_model(self, # Done; return the response. return response - def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -894,8 +921,10 @@ def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -919,18 +948,11 @@ def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -943,14 +965,15 @@ def export_model(self, # Done; return the response. return response - def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -985,8 +1008,10 @@ def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -1008,30 +1033,24 @@ def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1067,8 +1086,10 @@ def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1090,39 +1111,30 @@ def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1157,8 +1169,10 @@ def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1175,35 +1189,31 @@ def get_model_evaluation_slice(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1239,8 +1249,10 @@ def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1257,52 +1269,37 @@ def list_model_evaluation_slices(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceClient', -) +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index c4d4d8696b..eb547a5f9f 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation @@ -40,12 +49,15 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -79,7 +91,7 @@ def __iter__(self) -> Iterable[model.Model]: yield from page.models def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelsAsyncPager: @@ -99,12 +111,15 @@ class ListModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -142,7 +157,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationsPager: @@ -162,12 +177,15 @@ class ListModelEvaluationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelEvaluationsResponse], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +219,7 @@ def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: yield from page.model_evaluations def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationsAsyncPager: @@ -221,12 +239,15 @@ class ListModelEvaluationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -264,7 +285,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesPager: @@ -284,12 +305,15 @@ class ListModelEvaluationSlicesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelEvaluationSlicesResponse], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +347,7 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: yield from page.model_evaluation_slices def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesAsyncPager: @@ -343,12 +367,17 @@ class ListModelEvaluationSlicesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationSlicesResponse]], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] + ], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -370,7 +399,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + async def pages( + self, + ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -386,4 +417,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index 833862a1d6..5d1cb51abc 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry['grpc'] = ModelServiceGrpcTransport -_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport __all__ = ( - 'ModelServiceTransport', - 'ModelServiceGrpcTransport', - 'ModelServiceGrpcAsyncIOTransport', + "ModelServiceTransport", + "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index 40426aa4bd..37d2b7a4e7 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -37,29 +37,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -82,8 +82,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -92,17 +92,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -111,39 +113,25 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, - default_timeout=5.0, - client_info=client_info, + self.upload_model, default_timeout=5.0, client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( - self.get_model, - default_timeout=5.0, - client_info=client_info, + self.get_model, default_timeout=5.0, client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( - self.list_models, - default_timeout=5.0, - client_info=client_info, + self.list_models, default_timeout=5.0, client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( - self.update_model, - default_timeout=5.0, - client_info=client_info, + self.update_model, default_timeout=5.0, client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, - default_timeout=5.0, - client_info=client_info, + self.delete_model, default_timeout=5.0, client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( - self.export_model, - default_timeout=5.0, - client_info=client_info, + self.export_model, default_timeout=5.0, client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( - self.get_model_evaluation, - default_timeout=5.0, - client_info=client_info, + self.get_model_evaluation, default_timeout=5.0, client_info=client_info, ), self.list_model_evaluations: gapic_v1.method.wrap_method( self.list_model_evaluations, @@ -160,7 +148,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), - } @property @@ -169,96 +156,109 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def upload_model(self) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def upload_model( + self, + ) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model(self) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[ - model.Model, - typing.Awaitable[model.Model] - ]]: + def get_model( + self, + ) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[model.Model, typing.Awaitable[model.Model]], + ]: raise NotImplementedError() @property - def list_models(self) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse] - ]]: + def list_models( + self, + ) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse], + ], + ]: raise NotImplementedError() @property - def update_model(self) -> typing.Callable[ - [model_service.UpdateModelRequest], - typing.Union[ - gca_model.Model, - typing.Awaitable[gca_model.Model] - ]]: + def update_model( + self, + ) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], + ]: raise NotImplementedError() @property - def delete_model(self) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_model( + self, + ) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_model(self) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_model( + self, + ) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model_evaluation(self) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation] - ]]: + def get_model_evaluation( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation], + ], + ]: raise NotImplementedError() @property - def list_model_evaluations(self) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse] - ]]: + def list_model_evaluations( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse], + ], + ]: raise NotImplementedError() @property - def get_model_evaluation_slice(self) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] - ]]: + def get_model_evaluation_slice( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ], + ]: raise NotImplementedError() @property - def list_model_evaluation_slices(self) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] - ]]: + def list_model_evaluation_slices( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'ModelServiceTransport', -) +__all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 85db2fddd7..2cbac70e87 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -49,21 +49,24 @@ class ModelServiceGrpcTransport(ModelServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -175,13 +178,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -214,7 +219,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -232,17 +237,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def upload_model(self) -> Callable[ - [model_service.UploadModelRequest], - operations.Operation]: + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], operations.Operation]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -257,18 +260,16 @@ def upload_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'upload_model' not in self._stubs: - self._stubs['upload_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['upload_model'] + return self._stubs["upload_model"] @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -283,18 +284,18 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -309,18 +310,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def update_model(self) -> Callable[ - [model_service.UpdateModelRequest], - gca_model.Model]: + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -335,18 +336,18 @@ def update_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model' not in self._stubs: - self._stubs['update_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs['update_model'] + return self._stubs["update_model"] @property - def delete_model(self) -> Callable[ - [model_service.DeleteModelRequest], - operations.Operation]: + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -363,18 +364,18 @@ def delete_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model' not in self._stubs: - self._stubs['delete_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model'] + return self._stubs["delete_model"] @property - def export_model(self) -> Callable[ - [model_service.ExportModelRequest], - operations.Operation]: + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -392,18 +393,20 @@ def export_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_model' not in self._stubs: - self._stubs['export_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_model'] + return self._stubs["export_model"] @property - def get_model_evaluation(self) -> Callable[ - [model_service.GetModelEvaluationRequest], - model_evaluation.ModelEvaluation]: + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation + ]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -418,18 +421,21 @@ def get_model_evaluation(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation' not in self._stubs: - self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs['get_model_evaluation'] + return self._stubs["get_model_evaluation"] @property - def list_model_evaluations(self) -> Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse]: + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse, + ]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -444,18 +450,21 @@ def list_model_evaluations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluations' not in self._stubs: - self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs['list_model_evaluations'] + return self._stubs["list_model_evaluations"] @property - def get_model_evaluation_slice(self) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice]: + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice, + ]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -470,18 +479,21 @@ def get_model_evaluation_slice(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation_slice' not in self._stubs: - self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs['get_model_evaluation_slice'] + return self._stubs["get_model_evaluation_slice"] @property - def list_model_evaluation_slices(self) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse]: + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse, + ]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -496,15 +508,13 @@ def list_model_evaluation_slices(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluation_slices' not in self._stubs: - self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs['list_model_evaluation_slices'] + return self._stubs["list_model_evaluation_slices"] -__all__ = ( - 'ModelServiceGrpcTransport', -) +__all__ = ("ModelServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index bd8ae232f9..700014be02 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import model @@ -56,13 +56,15 @@ class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -91,22 +93,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -245,9 +249,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def upload_model(self) -> Callable[ - [model_service.UploadModelRequest], - Awaitable[operations.Operation]]: + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -262,18 +266,18 @@ def upload_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'upload_model' not in self._stubs: - self._stubs['upload_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['upload_model'] + return self._stubs["upload_model"] @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Awaitable[model.Model]]: + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -288,18 +292,20 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Awaitable[model_service.ListModelsResponse]]: + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -314,18 +320,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def update_model(self) -> Callable[ - [model_service.UpdateModelRequest], - Awaitable[gca_model.Model]]: + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -340,18 +346,18 @@ def update_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model' not in self._stubs: - self._stubs['update_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs['update_model'] + return self._stubs["update_model"] @property - def delete_model(self) -> Callable[ - [model_service.DeleteModelRequest], - Awaitable[operations.Operation]]: + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -368,18 +374,18 @@ def delete_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model' not in self._stubs: - self._stubs['delete_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model'] + return self._stubs["delete_model"] @property - def export_model(self) -> Callable[ - [model_service.ExportModelRequest], - Awaitable[operations.Operation]]: + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -397,18 +403,21 @@ def export_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_model' not in self._stubs: - self._stubs['export_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_model'] + return self._stubs["export_model"] @property - def get_model_evaluation(self) -> Callable[ - [model_service.GetModelEvaluationRequest], - Awaitable[model_evaluation.ModelEvaluation]]: + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation], + ]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -423,18 +432,21 @@ def get_model_evaluation(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation' not in self._stubs: - self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs['get_model_evaluation'] + return self._stubs["get_model_evaluation"] @property - def list_model_evaluations(self) -> Callable[ - [model_service.ListModelEvaluationsRequest], - Awaitable[model_service.ListModelEvaluationsResponse]]: + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse], + ]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -449,18 +461,21 @@ def list_model_evaluations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluations' not in self._stubs: - self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs['list_model_evaluations'] + return self._stubs["list_model_evaluations"] @property - def get_model_evaluation_slice(self) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - Awaitable[model_evaluation_slice.ModelEvaluationSlice]]: + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -475,18 +490,21 @@ def get_model_evaluation_slice(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation_slice' not in self._stubs: - self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs['get_model_evaluation_slice'] + return self._stubs["get_model_evaluation_slice"] @property - def list_model_evaluation_slices(self) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - Awaitable[model_service.ListModelEvaluationSlicesResponse]]: + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse], + ]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -501,15 +519,13 @@ def list_model_evaluation_slices(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluation_slices' not in self._stubs: - self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs['list_model_evaluation_slices'] + return self._stubs["list_model_evaluation_slices"] -__all__ = ( - 'ModelServiceGrpcAsyncIOTransport', -) +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py index f7f4d9b9ac..7f02b47358 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PipelineServiceAsyncClient __all__ = ( - 'PipelineServiceClient', - 'PipelineServiceAsyncClient', + "PipelineServiceClient", + "PipelineServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 6235697be1..063153700c 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -37,7 +37,9 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -61,22 +63,38 @@ class PipelineServiceAsyncClient: model_path = staticmethod(PipelineServiceClient.model_path) parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod(PipelineServiceClient.parse_training_pipeline_path) + parse_training_pipeline_path = staticmethod( + PipelineServiceClient.parse_training_pipeline_path + ) - common_billing_account_path = staticmethod(PipelineServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PipelineServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + PipelineServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PipelineServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PipelineServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + PipelineServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(PipelineServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PipelineServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + PipelineServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PipelineServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PipelineServiceClient.common_project_path) - parse_common_project_path = staticmethod(PipelineServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PipelineServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PipelineServiceClient.common_location_path) - parse_common_location_path = staticmethod(PipelineServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PipelineServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -119,14 +137,18 @@ def transport(self) -> PipelineServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient)) + get_transport_class = functools.partial( + type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -165,18 +187,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + async def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -221,8 +243,10 @@ async def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.CreateTrainingPipelineRequest(request) @@ -245,30 +269,24 @@ async def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + async def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -305,8 +323,10 @@ async def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.GetTrainingPipelineRequest(request) @@ -327,30 +347,24 @@ async def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesAsyncPager: + async def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: r"""Lists TrainingPipelines in a Location. Args: @@ -386,8 +400,10 @@ async def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.ListTrainingPipelinesRequest(request) @@ -408,39 +424,30 @@ async def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrainingPipelinesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a TrainingPipeline. Args: @@ -486,8 +493,10 @@ async def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.DeleteTrainingPipelineRequest(request) @@ -508,18 +517,11 @@ async def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -532,14 +534,15 @@ async def delete_training_pipeline(self, # Done; return the response. return response - async def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -578,8 +581,10 @@ async def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.CancelTrainingPipelineRequest(request) @@ -600,35 +605,23 @@ async def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceAsyncClient', -) +__all__ = ("PipelineServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 07f1ac0444..4efc2064b5 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -41,7 +41,9 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -59,13 +61,14 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry['grpc'] = PipelineServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PipelineServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry["grpc"] = PipelineServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +119,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -151,9 +154,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PipelineServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,99 +170,122 @@ def transport(self) -> PipelineServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -304,7 +329,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -314,7 +341,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -326,7 +355,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -338,8 +369,10 @@ def __init__(self, *, if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -358,15 +391,16 @@ def __init__(self, *, client_info=client_info, ) - def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -411,8 +445,10 @@ def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -436,30 +472,24 @@ def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -496,8 +526,10 @@ def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -519,30 +551,24 @@ def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -578,8 +604,10 @@ def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -601,39 +629,30 @@ def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -679,8 +698,10 @@ def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -702,18 +723,11 @@ def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -726,14 +740,15 @@ def delete_training_pipeline(self, # Done; return the response. return response - def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -772,8 +787,10 @@ def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -795,35 +812,23 @@ def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceClient', -) +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 6de70ee1f1..db2b4dd3a1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -38,12 +47,15 @@ class ListTrainingPipelinesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: yield from page.training_pipelines def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTrainingPipelinesAsyncPager: @@ -97,12 +109,17 @@ class ListTrainingPipelinesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[pipeline_service.ListTrainingPipelinesResponse]], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + async def pages( + self, + ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index f289718f83..9d4610087a 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] -_transport_registry['grpc'] = PipelineServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = PipelineServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport __all__ = ( - 'PipelineServiceTransport', - 'PipelineServiceGrpcTransport', - 'PipelineServiceGrpcAsyncIOTransport', + "PipelineServiceTransport", + "PipelineServiceGrpcTransport", + "PipelineServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 30070650b2..886219917f 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -21,14 +21,16 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -36,29 +38,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,8 +83,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -91,17 +93,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -134,7 +138,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), - } @property @@ -143,51 +146,58 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_training_pipeline(self) -> typing.Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - typing.Union[ - gca_training_pipeline.TrainingPipeline, - typing.Awaitable[gca_training_pipeline.TrainingPipeline] - ]]: + def create_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline], + ], + ]: raise NotImplementedError() @property - def get_training_pipeline(self) -> typing.Callable[ - [pipeline_service.GetTrainingPipelineRequest], - typing.Union[ - training_pipeline.TrainingPipeline, - typing.Awaitable[training_pipeline.TrainingPipeline] - ]]: + def get_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline], + ], + ]: raise NotImplementedError() @property - def list_training_pipelines(self) -> typing.Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - typing.Union[ - pipeline_service.ListTrainingPipelinesResponse, - typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse] - ]]: + def list_training_pipelines( + self, + ) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ], + ]: raise NotImplementedError() @property - def delete_training_pipeline(self) -> typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_training_pipeline(self) -> typing.Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'PipelineServiceTransport', -) +__all__ = ("PipelineServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 9c024143ef..8004a9a0a7 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -18,18 +18,20 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -48,21 +50,24 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -174,13 +179,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -213,7 +220,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -231,17 +238,18 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_training_pipeline(self) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline]: + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + gca_training_pipeline.TrainingPipeline, + ]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -257,18 +265,21 @@ def create_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_training_pipeline' not in self._stubs: - self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['create_training_pipeline'] + return self._stubs["create_training_pipeline"] @property - def get_training_pipeline(self) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline]: + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline, + ]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -283,18 +294,21 @@ def get_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_training_pipeline' not in self._stubs: - self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['get_training_pipeline'] + return self._stubs["get_training_pipeline"] @property - def list_training_pipelines(self) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse]: + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse, + ]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -309,18 +323,20 @@ def list_training_pipelines(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_training_pipelines' not in self._stubs: - self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs['list_training_pipelines'] + return self._stubs["list_training_pipelines"] @property - def delete_training_pipeline(self) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - operations.Operation]: + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation + ]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -335,18 +351,18 @@ def delete_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_training_pipeline' not in self._stubs: - self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_training_pipeline'] + return self._stubs["delete_training_pipeline"] @property - def cancel_training_pipeline(self) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - empty.Empty]: + def cancel_training_pipeline( + self, + ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -373,15 +389,13 @@ def cancel_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_training_pipeline' not in self._stubs: - self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_training_pipeline'] + return self._stubs["cancel_training_pipeline"] -__all__ = ( - 'PipelineServiceGrpcTransport', -) +__all__ = ("PipelineServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index 53bd371d65..a268ec1cd2 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,19 +18,21 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -55,13 +57,15 @@ class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -90,22 +94,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -244,9 +250,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_training_pipeline(self) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - Awaitable[gca_training_pipeline.TrainingPipeline]]: + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline], + ]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -262,18 +271,21 @@ def create_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_training_pipeline' not in self._stubs: - self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['create_training_pipeline'] + return self._stubs["create_training_pipeline"] @property - def get_training_pipeline(self) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - Awaitable[training_pipeline.TrainingPipeline]]: + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline], + ]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -288,18 +300,21 @@ def get_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_training_pipeline' not in self._stubs: - self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['get_training_pipeline'] + return self._stubs["get_training_pipeline"] @property - def list_training_pipelines(self) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - Awaitable[pipeline_service.ListTrainingPipelinesResponse]]: + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -314,18 +329,21 @@ def list_training_pipelines(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_training_pipelines' not in self._stubs: - self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs['list_training_pipelines'] + return self._stubs["list_training_pipelines"] @property - def delete_training_pipeline(self) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - Awaitable[operations.Operation]]: + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -340,18 +358,20 @@ def delete_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_training_pipeline' not in self._stubs: - self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_training_pipeline'] + return self._stubs["delete_training_pipeline"] @property - def cancel_training_pipeline(self) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - Awaitable[empty.Empty]]: + def cancel_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -378,15 +398,13 @@ def cancel_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_training_pipeline' not in self._stubs: - self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_training_pipeline'] + return self._stubs["cancel_training_pipeline"] -__all__ = ( - 'PipelineServiceGrpcAsyncIOTransport', -) +__all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py index d4047c335d..0c847693e0 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PredictionServiceAsyncClient __all__ = ( - 'PredictionServiceClient', - 'PredictionServiceAsyncClient', + "PredictionServiceClient", + "PredictionServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index 64b514608c..4d69a6635f 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -48,20 +48,34 @@ class PredictionServiceAsyncClient: endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) - common_billing_account_path = staticmethod(PredictionServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PredictionServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + PredictionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PredictionServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PredictionServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + PredictionServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(PredictionServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PredictionServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + PredictionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PredictionServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PredictionServiceClient.common_project_path) - parse_common_project_path = staticmethod(PredictionServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PredictionServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PredictionServiceClient.common_location_path) - parse_common_location_path = staticmethod(PredictionServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PredictionServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -104,14 +118,18 @@ def transport(self) -> PredictionServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient)) + get_transport_class = functools.partial( + type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PredictionServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -150,19 +168,19 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def predict(self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + async def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -222,8 +240,10 @@ async def predict(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = prediction_service.PredictRequest(request) @@ -249,33 +269,27 @@ async def predict(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def explain(self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + async def explain( + self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. If @@ -354,8 +368,10 @@ async def explain(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = prediction_service.ExplainRequest(request) @@ -383,38 +399,24 @@ async def explain(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PredictionServiceAsyncClient', -) +__all__ = ("PredictionServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 097cf3d0fe..042307eca1 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -48,13 +48,16 @@ class PredictionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] - _transport_registry['grpc'] = PredictionServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PredictionServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry["grpc"] = PredictionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[PredictionServiceTransport]: """Return an appropriate transport class. Args: @@ -105,7 +108,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -140,9 +143,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PredictionServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -157,77 +159,88 @@ def transport(self) -> PredictionServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -271,7 +284,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -281,7 +296,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -293,7 +310,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -305,8 +324,10 @@ def __init__(self, *, if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -325,16 +346,17 @@ def __init__(self, *, client_info=client_info, ) - def predict(self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -394,8 +416,10 @@ def predict(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a prediction_service.PredictRequest. @@ -421,33 +445,27 @@ def predict(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def explain(self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + def explain( + self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. If @@ -526,8 +544,10 @@ def explain(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a prediction_service.ExplainRequest. @@ -555,38 +575,24 @@ def explain(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PredictionServiceClient', -) +__all__ = ("PredictionServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py index 15b5acb198..9ec1369a05 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] -_transport_registry['grpc'] = PredictionServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = PredictionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport __all__ = ( - 'PredictionServiceTransport', - 'PredictionServiceGrpcTransport', - 'PredictionServiceGrpcAsyncIOTransport', + "PredictionServiceTransport", + "PredictionServiceGrpcTransport", + "PredictionServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index d391018e2c..df601f6bdd 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -76,8 +76,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -86,17 +86,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -105,37 +107,36 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( - self.predict, - default_timeout=5.0, - client_info=client_info, + self.predict, default_timeout=5.0, client_info=client_info, ), self.explain: gapic_v1.method.wrap_method( - self.explain, - default_timeout=5.0, - client_info=client_info, + self.explain, default_timeout=5.0, client_info=client_info, ), - } @property - def predict(self) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse] - ]]: + def predict( + self, + ) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse], + ], + ]: raise NotImplementedError() @property - def explain(self) -> typing.Callable[ - [prediction_service.ExplainRequest], - typing.Union[ - prediction_service.ExplainResponse, - typing.Awaitable[prediction_service.ExplainResponse] - ]]: + def explain( + self, + ) -> typing.Callable[ + [prediction_service.ExplainRequest], + typing.Union[ + prediction_service.ExplainResponse, + typing.Awaitable[prediction_service.ExplainResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'PredictionServiceTransport', -) +__all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index ae5dfad093..c8be8ca59d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -18,10 +18,10 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -43,21 +43,24 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -168,13 +171,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -207,7 +212,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -217,9 +222,11 @@ def grpc_channel(self) -> grpc.Channel: return self._grpc_channel @property - def predict(self) -> Callable[ - [prediction_service.PredictRequest], - prediction_service.PredictResponse]: + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], prediction_service.PredictResponse + ]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -234,18 +241,20 @@ def predict(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'predict' not in self._stubs: - self._stubs['predict'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs['predict'] + return self._stubs["predict"] @property - def explain(self) -> Callable[ - [prediction_service.ExplainRequest], - prediction_service.ExplainResponse]: + def explain( + self, + ) -> Callable[ + [prediction_service.ExplainRequest], prediction_service.ExplainResponse + ]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. @@ -271,15 +280,13 @@ def explain(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'explain' not in self._stubs: - self._stubs['explain'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', + if "explain" not in self._stubs: + self._stubs["explain"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs['explain'] + return self._stubs["explain"] -__all__ = ( - 'PredictionServiceGrpcTransport', -) +__all__ = ("PredictionServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 69fbb7edeb..8edd3f1aac 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -18,13 +18,13 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -50,13 +50,15 @@ class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -85,22 +87,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -222,9 +226,12 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def predict(self) -> Callable[ - [prediction_service.PredictRequest], - Awaitable[prediction_service.PredictResponse]]: + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse], + ]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -239,18 +246,21 @@ def predict(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'predict' not in self._stubs: - self._stubs['predict'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs['predict'] + return self._stubs["predict"] @property - def explain(self) -> Callable[ - [prediction_service.ExplainRequest], - Awaitable[prediction_service.ExplainResponse]]: + def explain( + self, + ) -> Callable[ + [prediction_service.ExplainRequest], + Awaitable[prediction_service.ExplainResponse], + ]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. @@ -276,15 +286,13 @@ def explain(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'explain' not in self._stubs: - self._stubs['explain'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', + if "explain" not in self._stubs: + self._stubs["explain"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs['explain'] + return self._stubs["explain"] -__all__ = ( - 'PredictionServiceGrpcAsyncIOTransport', -) +__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py index e4247d7758..49e9cdf0a0 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import SpecialistPoolServiceAsyncClient __all__ = ( - 'SpecialistPoolServiceClient', - 'SpecialistPoolServiceAsyncClient', + "SpecialistPoolServiceClient", + "SpecialistPoolServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index a6de6886e7..6907135b53 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,23 +57,43 @@ class SpecialistPoolServiceAsyncClient: DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT - specialist_pool_path = staticmethod(SpecialistPoolServiceClient.specialist_pool_path) - parse_specialist_pool_path = staticmethod(SpecialistPoolServiceClient.parse_specialist_pool_path) + specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.specialist_pool_path + ) + parse_specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.parse_specialist_pool_path + ) - common_billing_account_path = staticmethod(SpecialistPoolServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(SpecialistPoolServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(SpecialistPoolServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + SpecialistPoolServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(SpecialistPoolServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(SpecialistPoolServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + SpecialistPoolServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + SpecialistPoolServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) - parse_common_project_path = staticmethod(SpecialistPoolServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + SpecialistPoolServiceClient.parse_common_project_path + ) - common_location_path = staticmethod(SpecialistPoolServiceClient.common_location_path) - parse_common_location_path = staticmethod(SpecialistPoolServiceClient.parse_common_location_path) + common_location_path = staticmethod( + SpecialistPoolServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + SpecialistPoolServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -116,14 +136,19 @@ def transport(self) -> SpecialistPoolServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(SpecialistPoolServiceClient).get_transport_class, type(SpecialistPoolServiceClient)) + get_transport_class = functools.partial( + type(SpecialistPoolServiceClient).get_transport_class, + type(SpecialistPoolServiceClient), + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -162,18 +187,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a SpecialistPool. Args: @@ -221,8 +246,10 @@ async def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.CreateSpecialistPoolRequest(request) @@ -245,18 +272,11 @@ async def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -269,14 +289,15 @@ async def create_specialist_pool(self, # Done; return the response. return response - async def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + async def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -318,8 +339,10 @@ async def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.GetSpecialistPoolRequest(request) @@ -340,30 +363,24 @@ async def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsAsyncPager: + async def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: r"""Lists SpecialistPools in a Location. Args: @@ -399,8 +416,10 @@ async def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.ListSpecialistPoolsRequest(request) @@ -421,39 +440,30 @@ async def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListSpecialistPoolsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -500,8 +510,10 @@ async def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.DeleteSpecialistPoolRequest(request) @@ -522,18 +534,11 @@ async def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -546,15 +551,16 @@ async def delete_specialist_pool(self, # Done; return the response. return response - async def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates a SpecialistPool. Args: @@ -601,8 +607,10 @@ async def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.UpdateSpecialistPoolRequest(request) @@ -625,18 +633,13 @@ async def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -650,21 +653,14 @@ async def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceAsyncClient', -) +__all__ = ("SpecialistPoolServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index 813d6413ff..cde21b3720 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,13 +54,16 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport - _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +120,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +155,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: SpecialistPoolServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,77 +171,88 @@ def transport(self) -> SpecialistPoolServiceTransport: return self._transport @staticmethod - def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str: + def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" - return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) @staticmethod - def parse_specialist_pool_path(path: str) -> Dict[str,str]: + def parse_specialist_pool_path(path: str) -> Dict[str, str]: """Parse a specialist_pool path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -283,7 +296,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -293,7 +308,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -305,7 +322,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -317,8 +336,10 @@ def __init__(self, *, if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -337,15 +358,16 @@ def __init__(self, *, client_info=client_info, ) - def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -393,8 +415,10 @@ def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -418,18 +442,11 @@ def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -442,14 +459,15 @@ def create_specialist_pool(self, # Done; return the response. return response - def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -491,8 +509,10 @@ def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -514,30 +534,24 @@ def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -573,8 +587,10 @@ def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -596,39 +612,30 @@ def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -675,8 +682,10 @@ def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -698,18 +707,11 @@ def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -722,15 +724,16 @@ def delete_specialist_pool(self, # Done; return the response. return response - def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -777,8 +780,10 @@ def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -802,18 +807,13 @@ def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -827,21 +827,14 @@ def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceClient', -) +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 6b5d115c82..976bcf55b8 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service @@ -38,12 +47,15 @@ class ListSpecialistPoolsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: yield from page.specialist_pools def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListSpecialistPoolsAsyncPager: @@ -97,12 +109,17 @@ class ListSpecialistPoolsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + async def pages( + self, + ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index 80de7b209f..1bb2fbf22a 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -24,12 +24,14 @@ # Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] -_transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport -_transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[SpecialistPoolServiceTransport]] +_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport +_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( - 'SpecialistPoolServiceTransport', - 'SpecialistPoolServiceGrpcTransport', - 'SpecialistPoolServiceGrpcAsyncIOTransport', + "SpecialistPoolServiceTransport", + "SpecialistPoolServiceGrpcTransport", + "SpecialistPoolServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index 43c7e87f16..48ee079a5c 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -79,8 +79,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -89,17 +89,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -113,9 +115,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, - default_timeout=5.0, - client_info=client_info, + self.get_specialist_pool, default_timeout=5.0, client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, @@ -132,7 +132,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), - } @property @@ -141,51 +140,55 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool] - ]]: + def get_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool], + ], + ]: raise NotImplementedError() @property - def list_specialist_pools(self) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ]]: + def list_specialist_pools( + self, + ) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ], + ]: raise NotImplementedError() @property - def delete_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def update_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'SpecialistPoolServiceTransport', -) +__all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 256765e7eb..c1f9300de8 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,21 +51,24 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -177,13 +180,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -216,7 +221,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -234,17 +239,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_specialist_pool(self) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - operations.Operation]: + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -259,18 +264,21 @@ def create_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_specialist_pool' not in self._stubs: - self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_specialist_pool'] + return self._stubs["create_specialist_pool"] @property - def get_specialist_pool(self) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool]: + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool, + ]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -285,18 +293,21 @@ def get_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_specialist_pool' not in self._stubs: - self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs['get_specialist_pool'] + return self._stubs["get_specialist_pool"] @property - def list_specialist_pools(self) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse]: + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse, + ]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -311,18 +322,20 @@ def list_specialist_pools(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_specialist_pools' not in self._stubs: - self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs['list_specialist_pools'] + return self._stubs["list_specialist_pools"] @property - def delete_specialist_pool(self) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - operations.Operation]: + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -338,18 +351,20 @@ def delete_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_specialist_pool' not in self._stubs: - self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_specialist_pool'] + return self._stubs["delete_specialist_pool"] @property - def update_specialist_pool(self) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - operations.Operation]: + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -364,15 +379,13 @@ def update_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_specialist_pool' not in self._stubs: - self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_specialist_pool'] + return self._stubs["update_specialist_pool"] -__all__ = ( - 'SpecialistPoolServiceGrpcTransport', -) +__all__ = ("SpecialistPoolServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index 8bf8ea2c2e..592776b792 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import specialist_pool @@ -58,13 +58,15 @@ class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -93,22 +95,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -247,9 +251,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_specialist_pool(self) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -264,18 +271,21 @@ def create_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_specialist_pool' not in self._stubs: - self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_specialist_pool'] + return self._stubs["create_specialist_pool"] @property - def get_specialist_pool(self) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - Awaitable[specialist_pool.SpecialistPool]]: + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool], + ]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -290,18 +300,21 @@ def get_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_specialist_pool' not in self._stubs: - self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs['get_specialist_pool'] + return self._stubs["get_specialist_pool"] @property - def list_specialist_pools(self) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]]: + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -316,18 +329,21 @@ def list_specialist_pools(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_specialist_pools' not in self._stubs: - self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs['list_specialist_pools'] + return self._stubs["list_specialist_pools"] @property - def delete_specialist_pool(self) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -343,18 +359,21 @@ def delete_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_specialist_pool' not in self._stubs: - self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_specialist_pool'] + return self._stubs["delete_specialist_pool"] @property - def update_specialist_pool(self) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -369,15 +388,13 @@ def update_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_specialist_pool' not in self._stubs: - self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_specialist_pool'] + return self._stubs["update_specialist_pool"] -__all__ = ( - 'SpecialistPoolServiceGrpcAsyncIOTransport', -) +__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py index 4c173a843c..5c312868f1 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import VizierServiceAsyncClient __all__ = ( - 'VizierServiceClient', - 'VizierServiceAsyncClient', + "VizierServiceClient", + "VizierServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py index 4844bd0528..4bd90a79cd 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,20 +60,34 @@ class VizierServiceAsyncClient: trial_path = staticmethod(VizierServiceClient.trial_path) parse_trial_path = staticmethod(VizierServiceClient.parse_trial_path) - common_billing_account_path = staticmethod(VizierServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(VizierServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + VizierServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + VizierServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(VizierServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(VizierServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + VizierServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(VizierServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(VizierServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + VizierServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + VizierServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(VizierServiceClient.common_project_path) - parse_common_project_path = staticmethod(VizierServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + VizierServiceClient.parse_common_project_path + ) common_location_path = staticmethod(VizierServiceClient.common_location_path) - parse_common_location_path = staticmethod(VizierServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + VizierServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -116,14 +130,18 @@ def transport(self) -> VizierServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(VizierServiceClient).get_transport_class, type(VizierServiceClient)) + get_transport_class = functools.partial( + type(VizierServiceClient).get_transport_class, type(VizierServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, VizierServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, VizierServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the vizier service client. Args: @@ -162,18 +180,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_study(self, - request: vizier_service.CreateStudyRequest = None, - *, - parent: str = None, - study: gca_study.Study = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_study.Study: + async def create_study( + self, + request: vizier_service.CreateStudyRequest = None, + *, + parent: str = None, + study: gca_study.Study = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_study.Study: r"""Creates a Study. A resource name will be generated after creation of the Study. @@ -212,8 +230,10 @@ async def create_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.CreateStudyRequest(request) @@ -236,30 +256,24 @@ async def create_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_study(self, - request: vizier_service.GetStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + async def get_study( + self, + request: vizier_service.GetStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Gets a Study by name. Args: @@ -289,8 +303,10 @@ async def get_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.GetStudyRequest(request) @@ -311,30 +327,24 @@ async def get_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_studies(self, - request: vizier_service.ListStudiesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListStudiesAsyncPager: + async def list_studies( + self, + request: vizier_service.ListStudiesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListStudiesAsyncPager: r"""Lists all the studies in a region for an associated project. @@ -371,8 +381,10 @@ async def list_studies(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.ListStudiesRequest(request) @@ -393,39 +405,30 @@ async def list_studies(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListStudiesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def delete_study(self, - request: vizier_service.DeleteStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def delete_study( + self, + request: vizier_service.DeleteStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Study. Args: @@ -452,8 +455,10 @@ async def delete_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.DeleteStudyRequest(request) @@ -474,27 +479,23 @@ async def delete_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - async def lookup_study(self, - request: vizier_service.LookupStudyRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + async def lookup_study( + self, + request: vizier_service.LookupStudyRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Looks a study up using the user-defined display_name field instead of the fully qualified resource name. @@ -526,8 +527,10 @@ async def lookup_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.LookupStudyRequest(request) @@ -548,29 +551,23 @@ async def lookup_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def suggest_trials(self, - request: vizier_service.SuggestTrialsRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def suggest_trials( + self, + request: vizier_service.SuggestTrialsRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Adds one or more Trials to a Study, with parameter values suggested by AI Platform Vizier. Returns a long-running operation associated with the generation of Trial suggestions. @@ -613,18 +610,11 @@ async def suggest_trials(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -637,15 +627,16 @@ async def suggest_trials(self, # Done; return the response. return response - async def create_trial(self, - request: vizier_service.CreateTrialRequest = None, - *, - parent: str = None, - trial: study.Trial = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def create_trial( + self, + request: vizier_service.CreateTrialRequest = None, + *, + parent: str = None, + trial: study.Trial = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a user provided Trial to a Study. Args: @@ -686,8 +677,10 @@ async def create_trial(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.CreateTrialRequest(request) @@ -710,30 +703,24 @@ async def create_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_trial(self, - request: vizier_service.GetTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def get_trial( + self, + request: vizier_service.GetTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Gets a Trial. Args: @@ -768,8 +755,10 @@ async def get_trial(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.GetTrialRequest(request) @@ -790,30 +779,24 @@ async def get_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_trials(self, - request: vizier_service.ListTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrialsAsyncPager: + async def list_trials( + self, + request: vizier_service.ListTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrialsAsyncPager: r"""Lists the Trials associated with a Study. Args: @@ -849,8 +832,10 @@ async def list_trials(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.ListTrialsRequest(request) @@ -871,38 +856,29 @@ async def list_trials(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrialsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def add_trial_measurement(self, - request: vizier_service.AddTrialMeasurementRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def add_trial_measurement( + self, + request: vizier_service.AddTrialMeasurementRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a measurement of the objective metrics to a Trial. This measurement is assumed to have been taken before the Trial is complete. @@ -942,29 +918,25 @@ async def add_trial_measurement(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('trial_name', request.trial_name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def complete_trial(self, - request: vizier_service.CompleteTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def complete_trial( + self, + request: vizier_service.CompleteTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Marks a Trial as complete. Args: @@ -1002,30 +974,24 @@ async def complete_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_trial(self, - request: vizier_service.DeleteTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def delete_trial( + self, + request: vizier_service.DeleteTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Trial. Args: @@ -1051,8 +1017,10 @@ async def delete_trial(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.DeleteTrialRequest(request) @@ -1073,26 +1041,22 @@ async def delete_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - async def check_trial_early_stopping_state(self, - request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def check_trial_early_stopping_state( + self, + request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Checks whether a Trial should stop or not. Returns a long-running operation. When the operation is successful, it will contain a @@ -1134,18 +1098,13 @@ async def check_trial_early_stopping_state(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('trial_name', request.trial_name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1158,13 +1117,14 @@ async def check_trial_early_stopping_state(self, # Done; return the response. return response - async def stop_trial(self, - request: vizier_service.StopTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + async def stop_trial( + self, + request: vizier_service.StopTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Stops a Trial. Args: @@ -1202,30 +1162,24 @@ async def stop_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_optimal_trials(self, - request: vizier_service.ListOptimalTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> vizier_service.ListOptimalTrialsResponse: + async def list_optimal_trials( + self, + request: vizier_service.ListOptimalTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> vizier_service.ListOptimalTrialsResponse: r"""Lists the pareto-optimal Trials for multi-objective Study or the optimal Trials for single-objective Study. The definition of pareto-optimal can be checked in wiki page. @@ -1260,8 +1214,10 @@ async def list_optimal_trials(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = vizier_service.ListOptimalTrialsRequest(request) @@ -1282,38 +1238,24 @@ async def list_optimal_trials(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'VizierServiceAsyncClient', -) +__all__ = ("VizierServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index 13587919b9..85e381323d 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -52,13 +52,12 @@ class VizierServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[VizierServiceTransport]] - _transport_registry['grpc'] = VizierServiceGrpcTransport - _transport_registry['grpc_asyncio'] = VizierServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = VizierServiceGrpcTransport + _transport_registry["grpc_asyncio"] = VizierServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[VizierServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[VizierServiceTransport]: """Return an appropriate transport class. Args: @@ -113,7 +112,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -148,9 +147,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: VizierServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,99 +163,120 @@ def transport(self) -> VizierServiceTransport: return self._transport @staticmethod - def custom_job_path(project: str,location: str,custom_job: str,) -> str: + def custom_job_path(project: str, location: str, custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str,str]: + def parse_custom_job_path(path: str) -> Dict[str, str]: """Parse a custom_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def study_path(project: str,location: str,study: str,) -> str: + def study_path(project: str, location: str, study: str,) -> str: """Return a fully-qualified study string.""" - return "projects/{project}/locations/{location}/studies/{study}".format(project=project, location=location, study=study, ) + return "projects/{project}/locations/{location}/studies/{study}".format( + project=project, location=location, study=study, + ) @staticmethod - def parse_study_path(path: str) -> Dict[str,str]: + def parse_study_path(path: str) -> Dict[str, str]: """Parse a study path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str,location: str,study: str,trial: str,) -> str: + def trial_path(project: str, location: str, study: str, trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) @staticmethod - def parse_trial_path(path: str) -> Dict[str,str]: + def parse_trial_path(path: str) -> Dict[str, str]: """Parse a trial path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, VizierServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, VizierServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the vizier service client. Args: @@ -301,7 +320,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -311,7 +332,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +346,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -335,8 +360,10 @@ def __init__(self, *, if isinstance(transport, VizierServiceTransport): # transport is a VizierServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -355,15 +382,16 @@ def __init__(self, *, client_info=client_info, ) - def create_study(self, - request: vizier_service.CreateStudyRequest = None, - *, - parent: str = None, - study: gca_study.Study = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_study.Study: + def create_study( + self, + request: vizier_service.CreateStudyRequest = None, + *, + parent: str = None, + study: gca_study.Study = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_study.Study: r"""Creates a Study. A resource name will be generated after creation of the Study. @@ -402,8 +430,10 @@ def create_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.CreateStudyRequest. @@ -427,30 +457,24 @@ def create_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_study(self, - request: vizier_service.GetStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + def get_study( + self, + request: vizier_service.GetStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Gets a Study by name. Args: @@ -480,8 +504,10 @@ def get_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.GetStudyRequest. @@ -503,30 +529,24 @@ def get_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_studies(self, - request: vizier_service.ListStudiesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListStudiesPager: + def list_studies( + self, + request: vizier_service.ListStudiesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListStudiesPager: r"""Lists all the studies in a region for an associated project. @@ -563,8 +583,10 @@ def list_studies(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.ListStudiesRequest. @@ -586,39 +608,30 @@ def list_studies(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListStudiesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_study(self, - request: vizier_service.DeleteStudyRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def delete_study( + self, + request: vizier_service.DeleteStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Study. Args: @@ -645,8 +658,10 @@ def delete_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.DeleteStudyRequest. @@ -668,27 +683,23 @@ def delete_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def lookup_study(self, - request: vizier_service.LookupStudyRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Study: + def lookup_study( + self, + request: vizier_service.LookupStudyRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: r"""Looks a study up using the user-defined display_name field instead of the fully qualified resource name. @@ -720,8 +731,10 @@ def lookup_study(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.LookupStudyRequest. @@ -743,29 +756,23 @@ def lookup_study(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def suggest_trials(self, - request: vizier_service.SuggestTrialsRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def suggest_trials( + self, + request: vizier_service.SuggestTrialsRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Adds one or more Trials to a Study, with parameter values suggested by AI Platform Vizier. Returns a long-running operation associated with the generation of Trial suggestions. @@ -809,18 +816,11 @@ def suggest_trials(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation.from_gapic( @@ -833,15 +833,16 @@ def suggest_trials(self, # Done; return the response. return response - def create_trial(self, - request: vizier_service.CreateTrialRequest = None, - *, - parent: str = None, - trial: study.Trial = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def create_trial( + self, + request: vizier_service.CreateTrialRequest = None, + *, + parent: str = None, + trial: study.Trial = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a user provided Trial to a Study. Args: @@ -882,8 +883,10 @@ def create_trial(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.CreateTrialRequest. @@ -907,30 +910,24 @@ def create_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_trial(self, - request: vizier_service.GetTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def get_trial( + self, + request: vizier_service.GetTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Gets a Trial. Args: @@ -965,8 +962,10 @@ def get_trial(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.GetTrialRequest. @@ -988,30 +987,24 @@ def get_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_trials(self, - request: vizier_service.ListTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrialsPager: + def list_trials( + self, + request: vizier_service.ListTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrialsPager: r"""Lists the Trials associated with a Study. Args: @@ -1047,8 +1040,10 @@ def list_trials(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.ListTrialsRequest. @@ -1070,38 +1065,29 @@ def list_trials(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrialsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def add_trial_measurement(self, - request: vizier_service.AddTrialMeasurementRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def add_trial_measurement( + self, + request: vizier_service.AddTrialMeasurementRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Adds a measurement of the objective metrics to a Trial. This measurement is assumed to have been taken before the Trial is complete. @@ -1142,29 +1128,25 @@ def add_trial_measurement(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('trial_name', request.trial_name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def complete_trial(self, - request: vizier_service.CompleteTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def complete_trial( + self, + request: vizier_service.CompleteTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Marks a Trial as complete. Args: @@ -1203,30 +1185,24 @@ def complete_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_trial(self, - request: vizier_service.DeleteTrialRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def delete_trial( + self, + request: vizier_service.DeleteTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a Trial. Args: @@ -1252,8 +1228,10 @@ def delete_trial(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.DeleteTrialRequest. @@ -1275,26 +1253,22 @@ def delete_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def check_trial_early_stopping_state(self, - request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def check_trial_early_stopping_state( + self, + request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Checks whether a Trial should stop or not. Returns a long-running operation. When the operation is successful, it will contain a @@ -1332,23 +1306,20 @@ def check_trial_early_stopping_state(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.check_trial_early_stopping_state] + rpc = self._transport._wrapped_methods[ + self._transport.check_trial_early_stopping_state + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('trial_name', request.trial_name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation.from_gapic( @@ -1361,13 +1332,14 @@ def check_trial_early_stopping_state(self, # Done; return the response. return response - def stop_trial(self, - request: vizier_service.StopTrialRequest = None, - *, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> study.Trial: + def stop_trial( + self, + request: vizier_service.StopTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: r"""Stops a Trial. Args: @@ -1406,30 +1378,24 @@ def stop_trial(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_optimal_trials(self, - request: vizier_service.ListOptimalTrialsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> vizier_service.ListOptimalTrialsResponse: + def list_optimal_trials( + self, + request: vizier_service.ListOptimalTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> vizier_service.ListOptimalTrialsResponse: r"""Lists the pareto-optimal Trials for multi-objective Study or the optimal Trials for single-objective Study. The definition of pareto-optimal can be checked in wiki page. @@ -1464,8 +1430,10 @@ def list_optimal_trials(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a vizier_service.ListOptimalTrialsRequest. @@ -1487,38 +1455,24 @@ def list_optimal_trials(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'VizierServiceClient', -) +__all__ = ("VizierServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py index 5affed052e..c6e4fcdf63 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import study from google.cloud.aiplatform_v1beta1.types import vizier_service @@ -38,12 +47,15 @@ class ListStudiesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., vizier_service.ListStudiesResponse], - request: vizier_service.ListStudiesRequest, - response: vizier_service.ListStudiesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., vizier_service.ListStudiesResponse], + request: vizier_service.ListStudiesRequest, + response: vizier_service.ListStudiesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[study.Study]: yield from page.studies def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListStudiesAsyncPager: @@ -97,12 +109,15 @@ class ListStudiesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[vizier_service.ListStudiesResponse]], - request: vizier_service.ListStudiesRequest, - response: vizier_service.ListStudiesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[vizier_service.ListStudiesResponse]], + request: vizier_service.ListStudiesRequest, + response: vizier_service.ListStudiesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -140,7 +155,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTrialsPager: @@ -160,12 +175,15 @@ class ListTrialsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., vizier_service.ListTrialsResponse], - request: vizier_service.ListTrialsRequest, - response: vizier_service.ListTrialsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., vizier_service.ListTrialsResponse], + request: vizier_service.ListTrialsRequest, + response: vizier_service.ListTrialsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -199,7 +217,7 @@ def __iter__(self) -> Iterable[study.Trial]: yield from page.trials def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTrialsAsyncPager: @@ -219,12 +237,15 @@ class ListTrialsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[vizier_service.ListTrialsResponse]], - request: vizier_service.ListTrialsRequest, - response: vizier_service.ListTrialsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[vizier_service.ListTrialsResponse]], + request: vizier_service.ListTrialsRequest, + response: vizier_service.ListTrialsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -262,4 +283,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py index de1a35ae04..3ed347a603 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[VizierServiceTransport]] -_transport_registry['grpc'] = VizierServiceGrpcTransport -_transport_registry['grpc_asyncio'] = VizierServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = VizierServiceGrpcTransport +_transport_registry["grpc_asyncio"] = VizierServiceGrpcAsyncIOTransport __all__ = ( - 'VizierServiceTransport', - 'VizierServiceGrpcTransport', - 'VizierServiceGrpcAsyncIOTransport', + "VizierServiceTransport", + "VizierServiceGrpcTransport", + "VizierServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py index a6a5651b34..f09cd934b7 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class VizierServiceTransport(abc.ABC): """Abstract transport class for VizierService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,8 +81,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -91,17 +91,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -110,49 +112,31 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_study: gapic_v1.method.wrap_method( - self.create_study, - default_timeout=5.0, - client_info=client_info, + self.create_study, default_timeout=5.0, client_info=client_info, ), self.get_study: gapic_v1.method.wrap_method( - self.get_study, - default_timeout=5.0, - client_info=client_info, + self.get_study, default_timeout=5.0, client_info=client_info, ), self.list_studies: gapic_v1.method.wrap_method( - self.list_studies, - default_timeout=5.0, - client_info=client_info, + self.list_studies, default_timeout=5.0, client_info=client_info, ), self.delete_study: gapic_v1.method.wrap_method( - self.delete_study, - default_timeout=5.0, - client_info=client_info, + self.delete_study, default_timeout=5.0, client_info=client_info, ), self.lookup_study: gapic_v1.method.wrap_method( - self.lookup_study, - default_timeout=5.0, - client_info=client_info, + self.lookup_study, default_timeout=5.0, client_info=client_info, ), self.suggest_trials: gapic_v1.method.wrap_method( - self.suggest_trials, - default_timeout=5.0, - client_info=client_info, + self.suggest_trials, default_timeout=5.0, client_info=client_info, ), self.create_trial: gapic_v1.method.wrap_method( - self.create_trial, - default_timeout=5.0, - client_info=client_info, + self.create_trial, default_timeout=5.0, client_info=client_info, ), self.get_trial: gapic_v1.method.wrap_method( - self.get_trial, - default_timeout=5.0, - client_info=client_info, + self.get_trial, default_timeout=5.0, client_info=client_info, ), self.list_trials: gapic_v1.method.wrap_method( - self.list_trials, - default_timeout=5.0, - client_info=client_info, + self.list_trials, default_timeout=5.0, client_info=client_info, ), self.add_trial_measurement: gapic_v1.method.wrap_method( self.add_trial_measurement, @@ -160,14 +144,10 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.complete_trial: gapic_v1.method.wrap_method( - self.complete_trial, - default_timeout=5.0, - client_info=client_info, + self.complete_trial, default_timeout=5.0, client_info=client_info, ), self.delete_trial: gapic_v1.method.wrap_method( - self.delete_trial, - default_timeout=5.0, - client_info=client_info, + self.delete_trial, default_timeout=5.0, client_info=client_info, ), self.check_trial_early_stopping_state: gapic_v1.method.wrap_method( self.check_trial_early_stopping_state, @@ -175,16 +155,11 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.stop_trial: gapic_v1.method.wrap_method( - self.stop_trial, - default_timeout=5.0, - client_info=client_info, + self.stop_trial, default_timeout=5.0, client_info=client_info, ), self.list_optimal_trials: gapic_v1.method.wrap_method( - self.list_optimal_trials, - default_timeout=5.0, - client_info=client_info, + self.list_optimal_trials, default_timeout=5.0, client_info=client_info, ), - } @property @@ -193,141 +168,148 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_study(self) -> typing.Callable[ - [vizier_service.CreateStudyRequest], - typing.Union[ - gca_study.Study, - typing.Awaitable[gca_study.Study] - ]]: + def create_study( + self, + ) -> typing.Callable[ + [vizier_service.CreateStudyRequest], + typing.Union[gca_study.Study, typing.Awaitable[gca_study.Study]], + ]: raise NotImplementedError() @property - def get_study(self) -> typing.Callable[ - [vizier_service.GetStudyRequest], - typing.Union[ - study.Study, - typing.Awaitable[study.Study] - ]]: + def get_study( + self, + ) -> typing.Callable[ + [vizier_service.GetStudyRequest], + typing.Union[study.Study, typing.Awaitable[study.Study]], + ]: raise NotImplementedError() @property - def list_studies(self) -> typing.Callable[ - [vizier_service.ListStudiesRequest], - typing.Union[ - vizier_service.ListStudiesResponse, - typing.Awaitable[vizier_service.ListStudiesResponse] - ]]: + def list_studies( + self, + ) -> typing.Callable[ + [vizier_service.ListStudiesRequest], + typing.Union[ + vizier_service.ListStudiesResponse, + typing.Awaitable[vizier_service.ListStudiesResponse], + ], + ]: raise NotImplementedError() @property - def delete_study(self) -> typing.Callable[ - [vizier_service.DeleteStudyRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def delete_study( + self, + ) -> typing.Callable[ + [vizier_service.DeleteStudyRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def lookup_study(self) -> typing.Callable[ - [vizier_service.LookupStudyRequest], - typing.Union[ - study.Study, - typing.Awaitable[study.Study] - ]]: + def lookup_study( + self, + ) -> typing.Callable[ + [vizier_service.LookupStudyRequest], + typing.Union[study.Study, typing.Awaitable[study.Study]], + ]: raise NotImplementedError() @property - def suggest_trials(self) -> typing.Callable[ - [vizier_service.SuggestTrialsRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def suggest_trials( + self, + ) -> typing.Callable[ + [vizier_service.SuggestTrialsRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def create_trial(self) -> typing.Callable[ - [vizier_service.CreateTrialRequest], - typing.Union[ - study.Trial, - typing.Awaitable[study.Trial] - ]]: + def create_trial( + self, + ) -> typing.Callable[ + [vizier_service.CreateTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: raise NotImplementedError() @property - def get_trial(self) -> typing.Callable[ - [vizier_service.GetTrialRequest], - typing.Union[ - study.Trial, - typing.Awaitable[study.Trial] - ]]: + def get_trial( + self, + ) -> typing.Callable[ + [vizier_service.GetTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: raise NotImplementedError() @property - def list_trials(self) -> typing.Callable[ - [vizier_service.ListTrialsRequest], - typing.Union[ - vizier_service.ListTrialsResponse, - typing.Awaitable[vizier_service.ListTrialsResponse] - ]]: + def list_trials( + self, + ) -> typing.Callable[ + [vizier_service.ListTrialsRequest], + typing.Union[ + vizier_service.ListTrialsResponse, + typing.Awaitable[vizier_service.ListTrialsResponse], + ], + ]: raise NotImplementedError() @property - def add_trial_measurement(self) -> typing.Callable[ - [vizier_service.AddTrialMeasurementRequest], - typing.Union[ - study.Trial, - typing.Awaitable[study.Trial] - ]]: + def add_trial_measurement( + self, + ) -> typing.Callable[ + [vizier_service.AddTrialMeasurementRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: raise NotImplementedError() @property - def complete_trial(self) -> typing.Callable[ - [vizier_service.CompleteTrialRequest], - typing.Union[ - study.Trial, - typing.Awaitable[study.Trial] - ]]: + def complete_trial( + self, + ) -> typing.Callable[ + [vizier_service.CompleteTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: raise NotImplementedError() @property - def delete_trial(self) -> typing.Callable[ - [vizier_service.DeleteTrialRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def delete_trial( + self, + ) -> typing.Callable[ + [vizier_service.DeleteTrialRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def check_trial_early_stopping_state(self) -> typing.Callable[ - [vizier_service.CheckTrialEarlyStoppingStateRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def check_trial_early_stopping_state( + self, + ) -> typing.Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def stop_trial(self) -> typing.Callable[ - [vizier_service.StopTrialRequest], - typing.Union[ - study.Trial, - typing.Awaitable[study.Trial] - ]]: + def stop_trial( + self, + ) -> typing.Callable[ + [vizier_service.StopTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: raise NotImplementedError() @property - def list_optimal_trials(self) -> typing.Callable[ - [vizier_service.ListOptimalTrialsRequest], - typing.Union[ - vizier_service.ListOptimalTrialsResponse, - typing.Awaitable[vizier_service.ListOptimalTrialsResponse] - ]]: + def list_optimal_trials( + self, + ) -> typing.Callable[ + [vizier_service.ListOptimalTrialsRequest], + typing.Union[ + vizier_service.ListOptimalTrialsResponse, + typing.Awaitable[vizier_service.ListOptimalTrialsResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'VizierServiceTransport', -) +__all__ = ("VizierServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py index a9e3db2e54..c46d167f87 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,21 +51,24 @@ class VizierServiceGrpcTransport(VizierServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -177,13 +180,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -216,7 +221,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -234,17 +239,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_study(self) -> Callable[ - [vizier_service.CreateStudyRequest], - gca_study.Study]: + def create_study( + self, + ) -> Callable[[vizier_service.CreateStudyRequest], gca_study.Study]: r"""Return a callable for the create study method over gRPC. Creates a Study. A resource name will be generated @@ -260,18 +263,16 @@ def create_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_study' not in self._stubs: - self._stubs['create_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy', + if "create_study" not in self._stubs: + self._stubs["create_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy", request_serializer=vizier_service.CreateStudyRequest.serialize, response_deserializer=gca_study.Study.deserialize, ) - return self._stubs['create_study'] + return self._stubs["create_study"] @property - def get_study(self) -> Callable[ - [vizier_service.GetStudyRequest], - study.Study]: + def get_study(self) -> Callable[[vizier_service.GetStudyRequest], study.Study]: r"""Return a callable for the get study method over gRPC. Gets a Study by name. @@ -286,18 +287,20 @@ def get_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_study' not in self._stubs: - self._stubs['get_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/GetStudy', + if "get_study" not in self._stubs: + self._stubs["get_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetStudy", request_serializer=vizier_service.GetStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs['get_study'] + return self._stubs["get_study"] @property - def list_studies(self) -> Callable[ - [vizier_service.ListStudiesRequest], - vizier_service.ListStudiesResponse]: + def list_studies( + self, + ) -> Callable[ + [vizier_service.ListStudiesRequest], vizier_service.ListStudiesResponse + ]: r"""Return a callable for the list studies method over gRPC. Lists all the studies in a region for an associated @@ -313,18 +316,18 @@ def list_studies(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_studies' not in self._stubs: - self._stubs['list_studies'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/ListStudies', + if "list_studies" not in self._stubs: + self._stubs["list_studies"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListStudies", request_serializer=vizier_service.ListStudiesRequest.serialize, response_deserializer=vizier_service.ListStudiesResponse.deserialize, ) - return self._stubs['list_studies'] + return self._stubs["list_studies"] @property - def delete_study(self) -> Callable[ - [vizier_service.DeleteStudyRequest], - empty.Empty]: + def delete_study( + self, + ) -> Callable[[vizier_service.DeleteStudyRequest], empty.Empty]: r"""Return a callable for the delete study method over gRPC. Deletes a Study. @@ -339,18 +342,18 @@ def delete_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_study' not in self._stubs: - self._stubs['delete_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy', + if "delete_study" not in self._stubs: + self._stubs["delete_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy", request_serializer=vizier_service.DeleteStudyRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['delete_study'] + return self._stubs["delete_study"] @property - def lookup_study(self) -> Callable[ - [vizier_service.LookupStudyRequest], - study.Study]: + def lookup_study( + self, + ) -> Callable[[vizier_service.LookupStudyRequest], study.Study]: r"""Return a callable for the lookup study method over gRPC. Looks a study up using the user-defined display_name field @@ -366,18 +369,18 @@ def lookup_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'lookup_study' not in self._stubs: - self._stubs['lookup_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy', + if "lookup_study" not in self._stubs: + self._stubs["lookup_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy", request_serializer=vizier_service.LookupStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs['lookup_study'] + return self._stubs["lookup_study"] @property - def suggest_trials(self) -> Callable[ - [vizier_service.SuggestTrialsRequest], - operations.Operation]: + def suggest_trials( + self, + ) -> Callable[[vizier_service.SuggestTrialsRequest], operations.Operation]: r"""Return a callable for the suggest trials method over gRPC. Adds one or more Trials to a Study, with parameter values @@ -396,18 +399,18 @@ def suggest_trials(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'suggest_trials' not in self._stubs: - self._stubs['suggest_trials'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials', + if "suggest_trials" not in self._stubs: + self._stubs["suggest_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials", request_serializer=vizier_service.SuggestTrialsRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['suggest_trials'] + return self._stubs["suggest_trials"] @property - def create_trial(self) -> Callable[ - [vizier_service.CreateTrialRequest], - study.Trial]: + def create_trial( + self, + ) -> Callable[[vizier_service.CreateTrialRequest], study.Trial]: r"""Return a callable for the create trial method over gRPC. Adds a user provided Trial to a Study. @@ -422,18 +425,16 @@ def create_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_trial' not in self._stubs: - self._stubs['create_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial', + if "create_trial" not in self._stubs: + self._stubs["create_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial", request_serializer=vizier_service.CreateTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['create_trial'] + return self._stubs["create_trial"] @property - def get_trial(self) -> Callable[ - [vizier_service.GetTrialRequest], - study.Trial]: + def get_trial(self) -> Callable[[vizier_service.GetTrialRequest], study.Trial]: r"""Return a callable for the get trial method over gRPC. Gets a Trial. @@ -448,18 +449,20 @@ def get_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_trial' not in self._stubs: - self._stubs['get_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/GetTrial', + if "get_trial" not in self._stubs: + self._stubs["get_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetTrial", request_serializer=vizier_service.GetTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['get_trial'] + return self._stubs["get_trial"] @property - def list_trials(self) -> Callable[ - [vizier_service.ListTrialsRequest], - vizier_service.ListTrialsResponse]: + def list_trials( + self, + ) -> Callable[ + [vizier_service.ListTrialsRequest], vizier_service.ListTrialsResponse + ]: r"""Return a callable for the list trials method over gRPC. Lists the Trials associated with a Study. @@ -474,18 +477,18 @@ def list_trials(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_trials' not in self._stubs: - self._stubs['list_trials'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/ListTrials', + if "list_trials" not in self._stubs: + self._stubs["list_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListTrials", request_serializer=vizier_service.ListTrialsRequest.serialize, response_deserializer=vizier_service.ListTrialsResponse.deserialize, ) - return self._stubs['list_trials'] + return self._stubs["list_trials"] @property - def add_trial_measurement(self) -> Callable[ - [vizier_service.AddTrialMeasurementRequest], - study.Trial]: + def add_trial_measurement( + self, + ) -> Callable[[vizier_service.AddTrialMeasurementRequest], study.Trial]: r"""Return a callable for the add trial measurement method over gRPC. Adds a measurement of the objective metrics to a @@ -502,18 +505,18 @@ def add_trial_measurement(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_trial_measurement' not in self._stubs: - self._stubs['add_trial_measurement'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement', + if "add_trial_measurement" not in self._stubs: + self._stubs["add_trial_measurement"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement", request_serializer=vizier_service.AddTrialMeasurementRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['add_trial_measurement'] + return self._stubs["add_trial_measurement"] @property - def complete_trial(self) -> Callable[ - [vizier_service.CompleteTrialRequest], - study.Trial]: + def complete_trial( + self, + ) -> Callable[[vizier_service.CompleteTrialRequest], study.Trial]: r"""Return a callable for the complete trial method over gRPC. Marks a Trial as complete. @@ -528,18 +531,18 @@ def complete_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'complete_trial' not in self._stubs: - self._stubs['complete_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial', + if "complete_trial" not in self._stubs: + self._stubs["complete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial", request_serializer=vizier_service.CompleteTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['complete_trial'] + return self._stubs["complete_trial"] @property - def delete_trial(self) -> Callable[ - [vizier_service.DeleteTrialRequest], - empty.Empty]: + def delete_trial( + self, + ) -> Callable[[vizier_service.DeleteTrialRequest], empty.Empty]: r"""Return a callable for the delete trial method over gRPC. Deletes a Trial. @@ -554,18 +557,20 @@ def delete_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_trial' not in self._stubs: - self._stubs['delete_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial', + if "delete_trial" not in self._stubs: + self._stubs["delete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial", request_serializer=vizier_service.DeleteTrialRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['delete_trial'] + return self._stubs["delete_trial"] @property - def check_trial_early_stopping_state(self) -> Callable[ - [vizier_service.CheckTrialEarlyStoppingStateRequest], - operations.Operation]: + def check_trial_early_stopping_state( + self, + ) -> Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], operations.Operation + ]: r"""Return a callable for the check trial early stopping state method over gRPC. @@ -584,18 +589,18 @@ def check_trial_early_stopping_state(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'check_trial_early_stopping_state' not in self._stubs: - self._stubs['check_trial_early_stopping_state'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState', + if "check_trial_early_stopping_state" not in self._stubs: + self._stubs[ + "check_trial_early_stopping_state" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState", request_serializer=vizier_service.CheckTrialEarlyStoppingStateRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['check_trial_early_stopping_state'] + return self._stubs["check_trial_early_stopping_state"] @property - def stop_trial(self) -> Callable[ - [vizier_service.StopTrialRequest], - study.Trial]: + def stop_trial(self) -> Callable[[vizier_service.StopTrialRequest], study.Trial]: r"""Return a callable for the stop trial method over gRPC. Stops a Trial. @@ -610,18 +615,21 @@ def stop_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'stop_trial' not in self._stubs: - self._stubs['stop_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/StopTrial', + if "stop_trial" not in self._stubs: + self._stubs["stop_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/StopTrial", request_serializer=vizier_service.StopTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['stop_trial'] + return self._stubs["stop_trial"] @property - def list_optimal_trials(self) -> Callable[ - [vizier_service.ListOptimalTrialsRequest], - vizier_service.ListOptimalTrialsResponse]: + def list_optimal_trials( + self, + ) -> Callable[ + [vizier_service.ListOptimalTrialsRequest], + vizier_service.ListOptimalTrialsResponse, + ]: r"""Return a callable for the list optimal trials method over gRPC. Lists the pareto-optimal Trials for multi-objective Study or the @@ -639,15 +647,13 @@ def list_optimal_trials(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_optimal_trials' not in self._stubs: - self._stubs['list_optimal_trials'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials', + if "list_optimal_trials" not in self._stubs: + self._stubs["list_optimal_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials", request_serializer=vizier_service.ListOptimalTrialsRequest.serialize, response_deserializer=vizier_service.ListOptimalTrialsResponse.deserialize, ) - return self._stubs['list_optimal_trials'] + return self._stubs["list_optimal_trials"] -__all__ = ( - 'VizierServiceGrpcTransport', -) +__all__ = ("VizierServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py index fedbc26b71..fc88aa444e 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import study @@ -58,13 +58,15 @@ class VizierServiceGrpcAsyncIOTransport(VizierServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -93,22 +95,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -247,9 +251,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_study(self) -> Callable[ - [vizier_service.CreateStudyRequest], - Awaitable[gca_study.Study]]: + def create_study( + self, + ) -> Callable[[vizier_service.CreateStudyRequest], Awaitable[gca_study.Study]]: r"""Return a callable for the create study method over gRPC. Creates a Study. A resource name will be generated @@ -265,18 +269,18 @@ def create_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_study' not in self._stubs: - self._stubs['create_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy', + if "create_study" not in self._stubs: + self._stubs["create_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy", request_serializer=vizier_service.CreateStudyRequest.serialize, response_deserializer=gca_study.Study.deserialize, ) - return self._stubs['create_study'] + return self._stubs["create_study"] @property - def get_study(self) -> Callable[ - [vizier_service.GetStudyRequest], - Awaitable[study.Study]]: + def get_study( + self, + ) -> Callable[[vizier_service.GetStudyRequest], Awaitable[study.Study]]: r"""Return a callable for the get study method over gRPC. Gets a Study by name. @@ -291,18 +295,21 @@ def get_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_study' not in self._stubs: - self._stubs['get_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/GetStudy', + if "get_study" not in self._stubs: + self._stubs["get_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetStudy", request_serializer=vizier_service.GetStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs['get_study'] + return self._stubs["get_study"] @property - def list_studies(self) -> Callable[ - [vizier_service.ListStudiesRequest], - Awaitable[vizier_service.ListStudiesResponse]]: + def list_studies( + self, + ) -> Callable[ + [vizier_service.ListStudiesRequest], + Awaitable[vizier_service.ListStudiesResponse], + ]: r"""Return a callable for the list studies method over gRPC. Lists all the studies in a region for an associated @@ -318,18 +325,18 @@ def list_studies(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_studies' not in self._stubs: - self._stubs['list_studies'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/ListStudies', + if "list_studies" not in self._stubs: + self._stubs["list_studies"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListStudies", request_serializer=vizier_service.ListStudiesRequest.serialize, response_deserializer=vizier_service.ListStudiesResponse.deserialize, ) - return self._stubs['list_studies'] + return self._stubs["list_studies"] @property - def delete_study(self) -> Callable[ - [vizier_service.DeleteStudyRequest], - Awaitable[empty.Empty]]: + def delete_study( + self, + ) -> Callable[[vizier_service.DeleteStudyRequest], Awaitable[empty.Empty]]: r"""Return a callable for the delete study method over gRPC. Deletes a Study. @@ -344,18 +351,18 @@ def delete_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_study' not in self._stubs: - self._stubs['delete_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy', + if "delete_study" not in self._stubs: + self._stubs["delete_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy", request_serializer=vizier_service.DeleteStudyRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['delete_study'] + return self._stubs["delete_study"] @property - def lookup_study(self) -> Callable[ - [vizier_service.LookupStudyRequest], - Awaitable[study.Study]]: + def lookup_study( + self, + ) -> Callable[[vizier_service.LookupStudyRequest], Awaitable[study.Study]]: r"""Return a callable for the lookup study method over gRPC. Looks a study up using the user-defined display_name field @@ -371,18 +378,20 @@ def lookup_study(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'lookup_study' not in self._stubs: - self._stubs['lookup_study'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy', + if "lookup_study" not in self._stubs: + self._stubs["lookup_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy", request_serializer=vizier_service.LookupStudyRequest.serialize, response_deserializer=study.Study.deserialize, ) - return self._stubs['lookup_study'] + return self._stubs["lookup_study"] @property - def suggest_trials(self) -> Callable[ - [vizier_service.SuggestTrialsRequest], - Awaitable[operations.Operation]]: + def suggest_trials( + self, + ) -> Callable[ + [vizier_service.SuggestTrialsRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the suggest trials method over gRPC. Adds one or more Trials to a Study, with parameter values @@ -401,18 +410,18 @@ def suggest_trials(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'suggest_trials' not in self._stubs: - self._stubs['suggest_trials'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials', + if "suggest_trials" not in self._stubs: + self._stubs["suggest_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials", request_serializer=vizier_service.SuggestTrialsRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['suggest_trials'] + return self._stubs["suggest_trials"] @property - def create_trial(self) -> Callable[ - [vizier_service.CreateTrialRequest], - Awaitable[study.Trial]]: + def create_trial( + self, + ) -> Callable[[vizier_service.CreateTrialRequest], Awaitable[study.Trial]]: r"""Return a callable for the create trial method over gRPC. Adds a user provided Trial to a Study. @@ -427,18 +436,18 @@ def create_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_trial' not in self._stubs: - self._stubs['create_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial', + if "create_trial" not in self._stubs: + self._stubs["create_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial", request_serializer=vizier_service.CreateTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['create_trial'] + return self._stubs["create_trial"] @property - def get_trial(self) -> Callable[ - [vizier_service.GetTrialRequest], - Awaitable[study.Trial]]: + def get_trial( + self, + ) -> Callable[[vizier_service.GetTrialRequest], Awaitable[study.Trial]]: r"""Return a callable for the get trial method over gRPC. Gets a Trial. @@ -453,18 +462,20 @@ def get_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_trial' not in self._stubs: - self._stubs['get_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/GetTrial', + if "get_trial" not in self._stubs: + self._stubs["get_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetTrial", request_serializer=vizier_service.GetTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['get_trial'] + return self._stubs["get_trial"] @property - def list_trials(self) -> Callable[ - [vizier_service.ListTrialsRequest], - Awaitable[vizier_service.ListTrialsResponse]]: + def list_trials( + self, + ) -> Callable[ + [vizier_service.ListTrialsRequest], Awaitable[vizier_service.ListTrialsResponse] + ]: r"""Return a callable for the list trials method over gRPC. Lists the Trials associated with a Study. @@ -479,18 +490,18 @@ def list_trials(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_trials' not in self._stubs: - self._stubs['list_trials'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/ListTrials', + if "list_trials" not in self._stubs: + self._stubs["list_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListTrials", request_serializer=vizier_service.ListTrialsRequest.serialize, response_deserializer=vizier_service.ListTrialsResponse.deserialize, ) - return self._stubs['list_trials'] + return self._stubs["list_trials"] @property - def add_trial_measurement(self) -> Callable[ - [vizier_service.AddTrialMeasurementRequest], - Awaitable[study.Trial]]: + def add_trial_measurement( + self, + ) -> Callable[[vizier_service.AddTrialMeasurementRequest], Awaitable[study.Trial]]: r"""Return a callable for the add trial measurement method over gRPC. Adds a measurement of the objective metrics to a @@ -507,18 +518,18 @@ def add_trial_measurement(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'add_trial_measurement' not in self._stubs: - self._stubs['add_trial_measurement'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement', + if "add_trial_measurement" not in self._stubs: + self._stubs["add_trial_measurement"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement", request_serializer=vizier_service.AddTrialMeasurementRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['add_trial_measurement'] + return self._stubs["add_trial_measurement"] @property - def complete_trial(self) -> Callable[ - [vizier_service.CompleteTrialRequest], - Awaitable[study.Trial]]: + def complete_trial( + self, + ) -> Callable[[vizier_service.CompleteTrialRequest], Awaitable[study.Trial]]: r"""Return a callable for the complete trial method over gRPC. Marks a Trial as complete. @@ -533,18 +544,18 @@ def complete_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'complete_trial' not in self._stubs: - self._stubs['complete_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial', + if "complete_trial" not in self._stubs: + self._stubs["complete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial", request_serializer=vizier_service.CompleteTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['complete_trial'] + return self._stubs["complete_trial"] @property - def delete_trial(self) -> Callable[ - [vizier_service.DeleteTrialRequest], - Awaitable[empty.Empty]]: + def delete_trial( + self, + ) -> Callable[[vizier_service.DeleteTrialRequest], Awaitable[empty.Empty]]: r"""Return a callable for the delete trial method over gRPC. Deletes a Trial. @@ -559,18 +570,21 @@ def delete_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_trial' not in self._stubs: - self._stubs['delete_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial', + if "delete_trial" not in self._stubs: + self._stubs["delete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial", request_serializer=vizier_service.DeleteTrialRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['delete_trial'] + return self._stubs["delete_trial"] @property - def check_trial_early_stopping_state(self) -> Callable[ - [vizier_service.CheckTrialEarlyStoppingStateRequest], - Awaitable[operations.Operation]]: + def check_trial_early_stopping_state( + self, + ) -> Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the check trial early stopping state method over gRPC. @@ -589,18 +603,20 @@ def check_trial_early_stopping_state(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'check_trial_early_stopping_state' not in self._stubs: - self._stubs['check_trial_early_stopping_state'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState', + if "check_trial_early_stopping_state" not in self._stubs: + self._stubs[ + "check_trial_early_stopping_state" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState", request_serializer=vizier_service.CheckTrialEarlyStoppingStateRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['check_trial_early_stopping_state'] + return self._stubs["check_trial_early_stopping_state"] @property - def stop_trial(self) -> Callable[ - [vizier_service.StopTrialRequest], - Awaitable[study.Trial]]: + def stop_trial( + self, + ) -> Callable[[vizier_service.StopTrialRequest], Awaitable[study.Trial]]: r"""Return a callable for the stop trial method over gRPC. Stops a Trial. @@ -615,18 +631,21 @@ def stop_trial(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'stop_trial' not in self._stubs: - self._stubs['stop_trial'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/StopTrial', + if "stop_trial" not in self._stubs: + self._stubs["stop_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/StopTrial", request_serializer=vizier_service.StopTrialRequest.serialize, response_deserializer=study.Trial.deserialize, ) - return self._stubs['stop_trial'] + return self._stubs["stop_trial"] @property - def list_optimal_trials(self) -> Callable[ - [vizier_service.ListOptimalTrialsRequest], - Awaitable[vizier_service.ListOptimalTrialsResponse]]: + def list_optimal_trials( + self, + ) -> Callable[ + [vizier_service.ListOptimalTrialsRequest], + Awaitable[vizier_service.ListOptimalTrialsResponse], + ]: r"""Return a callable for the list optimal trials method over gRPC. Lists the pareto-optimal Trials for multi-objective Study or the @@ -644,15 +663,13 @@ def list_optimal_trials(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_optimal_trials' not in self._stubs: - self._stubs['list_optimal_trials'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials', + if "list_optimal_trials" not in self._stubs: + self._stubs["list_optimal_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials", request_serializer=vizier_service.ListOptimalTrialsRequest.serialize, response_deserializer=vizier_service.ListOptimalTrialsResponse.deserialize, ) - return self._stubs['list_optimal_trials'] + return self._stubs["list_optimal_trials"] -__all__ = ( - 'VizierServiceGrpcAsyncIOTransport', -) +__all__ = ("VizierServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 8cc21f36ae..10ed63d9a7 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,24 +15,12 @@ # limitations under the License. # -from .annotation import ( - Annotation, -) -from .annotation_spec import ( - AnnotationSpec, -) -from .artifact import ( - Artifact, -) -from .batch_prediction_job import ( - BatchPredictionJob, -) -from .completion_stats import ( - CompletionStats, -) -from .context import ( - Context, -) +from .annotation import Annotation +from .annotation_spec import AnnotationSpec +from .artifact import Artifact +from .batch_prediction_job import BatchPredictionJob +from .completion_stats import CompletionStats +from .context import Context from .custom_job import ( ContainerSpec, CustomJob, @@ -41,9 +29,7 @@ Scheduling, WorkerPoolSpec, ) -from .data_item import ( - DataItem, -) +from .data_item import DataItem from .data_labeling_job import ( ActiveLearningConfig, DataLabelingJob, @@ -75,12 +61,8 @@ ListDatasetsResponse, UpdateDatasetRequest, ) -from .deployed_model_ref import ( - DeployedModelRef, -) -from .encryption_spec import ( - EncryptionSpec, -) +from .deployed_model_ref import DeployedModelRef +from .encryption_spec import EncryptionSpec from .endpoint import ( DeployedModel, Endpoint, @@ -100,15 +82,9 @@ UndeployModelResponse, UpdateEndpointRequest, ) -from .env_var import ( - EnvVar, -) -from .event import ( - Event, -) -from .execution import ( - Execution, -) +from .env_var import EnvVar +from .event import Event +from .execution import Execution from .explanation import ( Attribution, Explanation, @@ -123,15 +99,9 @@ SmoothGradConfig, XraiAttribution, ) -from .explanation_metadata import ( - ExplanationMetadata, -) -from .feature_monitoring_stats import ( - FeatureStatsAnomaly, -) -from .hyperparameter_tuning_job import ( - HyperparameterTuningJob, -) +from .explanation_metadata import ExplanationMetadata +from .feature_monitoring_stats import FeatureStatsAnomaly +from .hyperparameter_tuning_job import HyperparameterTuningJob from .io import ( BigQueryDestination, BigQuerySource, @@ -176,9 +146,7 @@ UpdateModelDeploymentMonitoringJobOperationMetadata, UpdateModelDeploymentMonitoringJobRequest, ) -from .lineage_subgraph import ( - LineageSubgraph, -) +from .lineage_subgraph import LineageSubgraph from .machine_resources import ( AutomaticResources, AutoscalingMetricSpec, @@ -188,12 +156,8 @@ MachineSpec, ResourcesConsumed, ) -from .manual_batch_tuning_parameters import ( - ManualBatchTuningParameters, -) -from .metadata_schema import ( - MetadataSchema, -) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters +from .metadata_schema import MetadataSchema from .metadata_service import ( AddContextArtifactsAndExecutionsRequest, AddContextArtifactsAndExecutionsResponse, @@ -231,12 +195,8 @@ UpdateContextRequest, UpdateExecutionRequest, ) -from .metadata_store import ( - MetadataStore, -) -from .migratable_resource import ( - MigratableResource, -) +from .metadata_store import MetadataStore +from .migratable_resource import MigratableResource from .migration_service import ( BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, @@ -260,12 +220,8 @@ ModelMonitoringStatsAnomalies, ModelDeploymentMonitoringObjectiveType, ) -from .model_evaluation import ( - ModelEvaluation, -) -from .model_evaluation_slice import ( - ModelEvaluationSlice, -) +from .model_evaluation import ModelEvaluation +from .model_evaluation_slice import ModelEvaluationSlice from .model_monitoring import ( ModelMonitoringAlertConfig, ModelMonitoringObjectiveConfig, @@ -309,9 +265,7 @@ PredictRequest, PredictResponse, ) -from .specialist_pool import ( - SpecialistPool, -) +from .specialist_pool import SpecialistPool from .specialist_pool_service import ( CreateSpecialistPoolOperationMetadata, CreateSpecialistPoolRequest, @@ -336,9 +290,7 @@ TimestampSplit, TrainingPipeline, ) -from .user_action_reference import ( - UserActionReference, -) +from .user_action_reference import UserActionReference from .vizier_service import ( AddTrialMeasurementRequest, CheckTrialEarlyStoppingStateMetatdata, @@ -365,261 +317,261 @@ ) __all__ = ( - 'AcceleratorType', - 'Annotation', - 'AnnotationSpec', - 'Artifact', - 'BatchPredictionJob', - 'CompletionStats', - 'Context', - 'ContainerSpec', - 'CustomJob', - 'CustomJobSpec', - 'PythonPackageSpec', - 'Scheduling', - 'WorkerPoolSpec', - 'DataItem', - 'ActiveLearningConfig', - 'DataLabelingJob', - 'SampleConfig', - 'TrainingConfig', - 'Dataset', - 'ExportDataConfig', - 'ImportDataConfig', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'DeleteDatasetRequest', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'GetAnnotationSpecRequest', - 'GetDatasetRequest', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'UpdateDatasetRequest', - 'DeployedModelRef', - 'EncryptionSpec', - 'DeployedModel', - 'Endpoint', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateEndpointRequest', - 'EnvVar', - 'Event', - 'Execution', - 'Attribution', - 'Explanation', - 'ExplanationMetadataOverride', - 'ExplanationParameters', - 'ExplanationSpec', - 'ExplanationSpecOverride', - 'FeatureNoiseSigma', - 'IntegratedGradientsAttribution', - 'ModelExplanation', - 'SampledShapleyAttribution', - 'SmoothGradConfig', - 'XraiAttribution', - 'ExplanationMetadata', - 'FeatureStatsAnomaly', - 'HyperparameterTuningJob', - 'BigQueryDestination', - 'BigQuerySource', - 'ContainerRegistryDestination', - 'GcsDestination', - 'GcsSource', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'CreateModelDeploymentMonitoringJobRequest', - 'DeleteBatchPredictionJobRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteHyperparameterTuningJobRequest', - 'DeleteModelDeploymentMonitoringJobRequest', - 'GetBatchPredictionJobRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetHyperparameterTuningJobRequest', - 'GetModelDeploymentMonitoringJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'ListModelDeploymentMonitoringJobsRequest', - 'ListModelDeploymentMonitoringJobsResponse', - 'PauseModelDeploymentMonitoringJobRequest', - 'ResumeModelDeploymentMonitoringJobRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', - 'UpdateModelDeploymentMonitoringJobOperationMetadata', - 'UpdateModelDeploymentMonitoringJobRequest', - 'JobState', - 'LineageSubgraph', - 'AutomaticResources', - 'AutoscalingMetricSpec', - 'BatchDedicatedResources', - 'DedicatedResources', - 'DiskSpec', - 'MachineSpec', - 'ResourcesConsumed', - 'ManualBatchTuningParameters', - 'MetadataSchema', - 'AddContextArtifactsAndExecutionsRequest', - 'AddContextArtifactsAndExecutionsResponse', - 'AddContextChildrenRequest', - 'AddContextChildrenResponse', - 'AddExecutionEventsRequest', - 'AddExecutionEventsResponse', - 'CreateArtifactRequest', - 'CreateContextRequest', - 'CreateExecutionRequest', - 'CreateMetadataSchemaRequest', - 'CreateMetadataStoreOperationMetadata', - 'CreateMetadataStoreRequest', - 'DeleteContextRequest', - 'DeleteMetadataStoreOperationMetadata', - 'DeleteMetadataStoreRequest', - 'GetArtifactRequest', - 'GetContextRequest', - 'GetExecutionRequest', - 'GetMetadataSchemaRequest', - 'GetMetadataStoreRequest', - 'ListArtifactsRequest', - 'ListArtifactsResponse', - 'ListContextsRequest', - 'ListContextsResponse', - 'ListExecutionsRequest', - 'ListExecutionsResponse', - 'ListMetadataSchemasRequest', - 'ListMetadataSchemasResponse', - 'ListMetadataStoresRequest', - 'ListMetadataStoresResponse', - 'QueryContextLineageSubgraphRequest', - 'QueryExecutionInputsAndOutputsRequest', - 'UpdateArtifactRequest', - 'UpdateContextRequest', - 'UpdateExecutionRequest', - 'MetadataStore', - 'MigratableResource', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'Model', - 'ModelContainerSpec', - 'Port', - 'PredictSchemata', - 'ModelDeploymentMonitoringBigQueryTable', - 'ModelDeploymentMonitoringJob', - 'ModelDeploymentMonitoringObjectiveConfig', - 'ModelDeploymentMonitoringScheduleConfig', - 'ModelMonitoringStatsAnomalies', - 'ModelDeploymentMonitoringObjectiveType', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'ModelMonitoringAlertConfig', - 'ModelMonitoringObjectiveConfig', - 'SamplingStrategy', - 'ThresholdConfig', - 'DeleteModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 'UploadModelResponse', - 'DeleteOperationMetadata', - 'GenericOperationMetadata', - 'CancelTrainingPipelineRequest', - 'CreateTrainingPipelineRequest', - 'DeleteTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'PipelineState', - 'ExplainRequest', - 'ExplainResponse', - 'PredictRequest', - 'PredictResponse', - 'SpecialistPool', - 'CreateSpecialistPoolOperationMetadata', - 'CreateSpecialistPoolRequest', - 'DeleteSpecialistPoolRequest', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'Measurement', - 'Study', - 'StudySpec', - 'Trial', - 'FilterSplit', - 'FractionSplit', - 'InputDataConfig', - 'PredefinedSplit', - 'TimestampSplit', - 'TrainingPipeline', - 'UserActionReference', - 'AddTrialMeasurementRequest', - 'CheckTrialEarlyStoppingStateMetatdata', - 'CheckTrialEarlyStoppingStateRequest', - 'CheckTrialEarlyStoppingStateResponse', - 'CompleteTrialRequest', - 'CreateStudyRequest', - 'CreateTrialRequest', - 'DeleteStudyRequest', - 'DeleteTrialRequest', - 'GetStudyRequest', - 'GetTrialRequest', - 'ListOptimalTrialsRequest', - 'ListOptimalTrialsResponse', - 'ListStudiesRequest', - 'ListStudiesResponse', - 'ListTrialsRequest', - 'ListTrialsResponse', - 'LookupStudyRequest', - 'StopTrialRequest', - 'SuggestTrialsMetadata', - 'SuggestTrialsRequest', - 'SuggestTrialsResponse', + "AcceleratorType", + "Annotation", + "AnnotationSpec", + "Artifact", + "BatchPredictionJob", + "CompletionStats", + "Context", + "ContainerSpec", + "CustomJob", + "CustomJobSpec", + "PythonPackageSpec", + "Scheduling", + "WorkerPoolSpec", + "DataItem", + "ActiveLearningConfig", + "DataLabelingJob", + "SampleConfig", + "TrainingConfig", + "Dataset", + "ExportDataConfig", + "ImportDataConfig", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "DeleteDatasetRequest", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "GetAnnotationSpecRequest", + "GetDatasetRequest", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "UpdateDatasetRequest", + "DeployedModelRef", + "EncryptionSpec", + "DeployedModel", + "Endpoint", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateEndpointRequest", + "EnvVar", + "Event", + "Execution", + "Attribution", + "Explanation", + "ExplanationMetadataOverride", + "ExplanationParameters", + "ExplanationSpec", + "ExplanationSpecOverride", + "FeatureNoiseSigma", + "IntegratedGradientsAttribution", + "ModelExplanation", + "SampledShapleyAttribution", + "SmoothGradConfig", + "XraiAttribution", + "ExplanationMetadata", + "FeatureStatsAnomaly", + "HyperparameterTuningJob", + "BigQueryDestination", + "BigQuerySource", + "ContainerRegistryDestination", + "GcsDestination", + "GcsSource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "CreateModelDeploymentMonitoringJobRequest", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteModelDeploymentMonitoringJobRequest", + "GetBatchPredictionJobRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetHyperparameterTuningJobRequest", + "GetModelDeploymentMonitoringJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListModelDeploymentMonitoringJobsRequest", + "ListModelDeploymentMonitoringJobsResponse", + "PauseModelDeploymentMonitoringJobRequest", + "ResumeModelDeploymentMonitoringJobRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesResponse", + "UpdateModelDeploymentMonitoringJobOperationMetadata", + "UpdateModelDeploymentMonitoringJobRequest", + "JobState", + "LineageSubgraph", + "AutomaticResources", + "AutoscalingMetricSpec", + "BatchDedicatedResources", + "DedicatedResources", + "DiskSpec", + "MachineSpec", + "ResourcesConsumed", + "ManualBatchTuningParameters", + "MetadataSchema", + "AddContextArtifactsAndExecutionsRequest", + "AddContextArtifactsAndExecutionsResponse", + "AddContextChildrenRequest", + "AddContextChildrenResponse", + "AddExecutionEventsRequest", + "AddExecutionEventsResponse", + "CreateArtifactRequest", + "CreateContextRequest", + "CreateExecutionRequest", + "CreateMetadataSchemaRequest", + "CreateMetadataStoreOperationMetadata", + "CreateMetadataStoreRequest", + "DeleteContextRequest", + "DeleteMetadataStoreOperationMetadata", + "DeleteMetadataStoreRequest", + "GetArtifactRequest", + "GetContextRequest", + "GetExecutionRequest", + "GetMetadataSchemaRequest", + "GetMetadataStoreRequest", + "ListArtifactsRequest", + "ListArtifactsResponse", + "ListContextsRequest", + "ListContextsResponse", + "ListExecutionsRequest", + "ListExecutionsResponse", + "ListMetadataSchemasRequest", + "ListMetadataSchemasResponse", + "ListMetadataStoresRequest", + "ListMetadataStoresResponse", + "QueryContextLineageSubgraphRequest", + "QueryExecutionInputsAndOutputsRequest", + "UpdateArtifactRequest", + "UpdateContextRequest", + "UpdateExecutionRequest", + "MetadataStore", + "MigratableResource", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceRequest", + "MigrateResourceResponse", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "Model", + "ModelContainerSpec", + "Port", + "PredictSchemata", + "ModelDeploymentMonitoringBigQueryTable", + "ModelDeploymentMonitoringJob", + "ModelDeploymentMonitoringObjectiveConfig", + "ModelDeploymentMonitoringScheduleConfig", + "ModelMonitoringStatsAnomalies", + "ModelDeploymentMonitoringObjectiveType", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelMonitoringAlertConfig", + "ModelMonitoringObjectiveConfig", + "SamplingStrategy", + "ThresholdConfig", + "DeleteModelRequest", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "DeleteOperationMetadata", + "GenericOperationMetadata", + "CancelTrainingPipelineRequest", + "CreateTrainingPipelineRequest", + "DeleteTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "PipelineState", + "ExplainRequest", + "ExplainResponse", + "PredictRequest", + "PredictResponse", + "SpecialistPool", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "DeleteSpecialistPoolRequest", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "Measurement", + "Study", + "StudySpec", + "Trial", + "FilterSplit", + "FractionSplit", + "InputDataConfig", + "PredefinedSplit", + "TimestampSplit", + "TrainingPipeline", + "UserActionReference", + "AddTrialMeasurementRequest", + "CheckTrialEarlyStoppingStateMetatdata", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CompleteTrialRequest", + "CreateStudyRequest", + "CreateTrialRequest", + "DeleteStudyRequest", + "DeleteTrialRequest", + "GetStudyRequest", + "GetTrialRequest", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", + "ListStudiesRequest", + "ListStudiesResponse", + "ListTrialsRequest", + "ListTrialsResponse", + "LookupStudyRequest", + "StopTrialRequest", + "SuggestTrialsMetadata", + "SuggestTrialsRequest", + "SuggestTrialsResponse", ) diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index 65471c7234..8c6968952c 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'AcceleratorType', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"AcceleratorType",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index 4b769480a8..a42ef0da82 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Annotation', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Annotation",}, ) @@ -94,22 +91,16 @@ class Annotation(proto.Message): payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field(proto.MESSAGE, number=5, - message=user_action_reference.UserActionReference, + annotation_source = proto.Field( + proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, ) labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index b60bcebb5f..e921e25971 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'AnnotationSpec', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"AnnotationSpec",}, ) @@ -58,13 +55,9 @@ class AnnotationSpec(proto.Message): display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/artifact.py b/google/cloud/aiplatform_v1beta1/types/artifact.py index b35ae286d7..7d959a3877 100644 --- a/google/cloud/aiplatform_v1beta1/types/artifact.py +++ b/google/cloud/aiplatform_v1beta1/types/artifact.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Artifact', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Artifact",}, ) @@ -107,6 +104,7 @@ class Artifact(proto.Message): description (str): Description of the Artifact """ + class State(proto.Enum): r"""Describes the state of the Artifact.""" STATE_UNSPECIFIED = 0 @@ -123,25 +121,17 @@ class State(proto.Enum): labels = proto.MapField(proto.STRING, proto.STRING, number=10) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - state = proto.Field(proto.ENUM, number=13, - enum=State, - ) + state = proto.Field(proto.ENUM, number=13, enum=State,) schema_title = proto.Field(proto.STRING, number=14) schema_version = proto.Field(proto.STRING, number=15) - metadata = proto.Field(proto.MESSAGE, number=16, - message=struct.Struct, - ) + metadata = proto.Field(proto.MESSAGE, number=16, message=struct.Struct,) description = proto.Field(proto.STRING, number=17) diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index b2bcab9302..69e44eff7f 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -18,23 +18,24 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import completion_stats as gca_completion_stats +from google.cloud.aiplatform_v1beta1.types import ( + completion_stats as gca_completion_stats, +) from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import ( + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, +) from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'BatchPredictionJob', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"BatchPredictionJob",}, ) @@ -190,6 +191,7 @@ class BatchPredictionJob(proto.Message): resources created by the BatchPredictionJob will be encrypted with the provided encryption key. """ + class InputConfig(proto.Message): r"""Configures the input to ``BatchPredictionJob``. @@ -216,12 +218,12 @@ class InputConfig(proto.Message): ``supported_input_storage_formats``. """ - gcs_source = proto.Field(proto.MESSAGE, number=2, oneof='source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, ) - bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', - message=io.BigQuerySource, + bigquery_source = proto.Field( + proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, ) instances_format = proto.Field(proto.STRING, number=1) @@ -291,11 +293,14 @@ class OutputConfig(proto.Message): ``supported_output_storage_formats``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, ) - bigquery_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', + bigquery_destination = proto.Field( + proto.MESSAGE, + number=3, + oneof="destination", message=io.BigQueryDestination, ) @@ -316,9 +321,13 @@ class OutputInfo(proto.Message): prediction output is written. """ - gcs_output_directory = proto.Field(proto.STRING, number=1, oneof='output_location') + gcs_output_directory = proto.Field( + proto.STRING, number=1, oneof="output_location" + ) - bigquery_output_dataset = proto.Field(proto.STRING, number=2, oneof='output_location') + bigquery_output_dataset = proto.Field( + proto.STRING, number=2, oneof="output_location" + ) name = proto.Field(proto.STRING, number=1) @@ -326,76 +335,58 @@ class OutputInfo(proto.Message): model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, - message=InputConfig, - ) + input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) - model_parameters = proto.Field(proto.MESSAGE, number=5, - message=struct.Value, - ) + model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - output_config = proto.Field(proto.MESSAGE, number=6, - message=OutputConfig, - ) + output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) - dedicated_resources = proto.Field(proto.MESSAGE, number=7, - message=machine_resources.BatchDedicatedResources, + dedicated_resources = proto.Field( + proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, ) - manual_batch_tuning_parameters = proto.Field(proto.MESSAGE, number=8, + manual_batch_tuning_parameters = proto.Field( + proto.MESSAGE, + number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) generate_explanation = proto.Field(proto.BOOL, number=23) - explanation_spec = proto.Field(proto.MESSAGE, number=25, - message=explanation.ExplanationSpec, + explanation_spec = proto.Field( + proto.MESSAGE, number=25, message=explanation.ExplanationSpec, ) - output_info = proto.Field(proto.MESSAGE, number=9, - message=OutputInfo, - ) + output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) - state = proto.Field(proto.ENUM, number=10, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - error = proto.Field(proto.MESSAGE, number=11, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) - partial_failures = proto.RepeatedField(proto.MESSAGE, number=12, - message=status.Status, + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=12, message=status.Status, ) - resources_consumed = proto.Field(proto.MESSAGE, number=13, - message=machine_resources.ResourcesConsumed, + resources_consumed = proto.Field( + proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, ) - completion_stats = proto.Field(proto.MESSAGE, number=14, - message=gca_completion_stats.CompletionStats, + completion_stats = proto.Field( + proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=16, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=17, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=18, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=19) - encryption_spec = proto.Field(proto.MESSAGE, number=24, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/completion_stats.py b/google/cloud/aiplatform_v1beta1/types/completion_stats.py index 3874f412df..165be59634 100644 --- a/google/cloud/aiplatform_v1beta1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/completion_stats.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'CompletionStats', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"CompletionStats",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/context.py b/google/cloud/aiplatform_v1beta1/types/context.py index 59f5289b48..b881e98d0e 100644 --- a/google/cloud/aiplatform_v1beta1/types/context.py +++ b/google/cloud/aiplatform_v1beta1/types/context.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Context', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Context",}, ) @@ -109,13 +106,9 @@ class Context(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=9) - create_time = proto.Field(proto.MESSAGE, number=10, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) parent_contexts = proto.RepeatedField(proto.STRING, number=12) @@ -123,9 +116,7 @@ class Context(proto.Message): schema_version = proto.Field(proto.STRING, number=14) - metadata = proto.Field(proto.MESSAGE, number=15, - message=struct.Struct, - ) + metadata = proto.Field(proto.MESSAGE, number=15, message=struct.Struct,) description = proto.Field(proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 9de4e3b5fa..1d148b7777 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CustomJob', - 'CustomJobSpec', - 'WorkerPoolSpec', - 'ContainerSpec', - 'PythonPackageSpec', - 'Scheduling', + "CustomJob", + "CustomJobSpec", + "WorkerPoolSpec", + "ContainerSpec", + "PythonPackageSpec", + "Scheduling", }, ) @@ -95,38 +95,24 @@ class CustomJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, - message='CustomJobSpec', - ) + job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) - state = proto.Field(proto.ENUM, number=5, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=10, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) - encryption_spec = proto.Field(proto.MESSAGE, number=12, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, ) @@ -191,20 +177,18 @@ class CustomJobSpec(proto.Message): ``//logs/`` """ - worker_pool_specs = proto.RepeatedField(proto.MESSAGE, number=1, - message='WorkerPoolSpec', + worker_pool_specs = proto.RepeatedField( + proto.MESSAGE, number=1, message="WorkerPoolSpec", ) - scheduling = proto.Field(proto.MESSAGE, number=3, - message='Scheduling', - ) + scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) service_account = proto.Field(proto.STRING, number=4) network = proto.Field(proto.STRING, number=5) - base_output_directory = proto.Field(proto.MESSAGE, number=6, - message=io.GcsDestination, + base_output_directory = proto.Field( + proto.MESSAGE, number=6, message=io.GcsDestination, ) @@ -226,22 +210,22 @@ class WorkerPoolSpec(proto.Message): Disk spec. """ - container_spec = proto.Field(proto.MESSAGE, number=6, oneof='task', - message='ContainerSpec', + container_spec = proto.Field( + proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", ) - python_package_spec = proto.Field(proto.MESSAGE, number=7, oneof='task', - message='PythonPackageSpec', + python_package_spec = proto.Field( + proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", ) - machine_spec = proto.Field(proto.MESSAGE, number=1, - message=machine_resources.MachineSpec, + machine_spec = proto.Field( + proto.MESSAGE, number=1, message=machine_resources.MachineSpec, ) replica_count = proto.Field(proto.INT64, number=2) - disk_spec = proto.Field(proto.MESSAGE, number=5, - message=machine_resources.DiskSpec, + disk_spec = proto.Field( + proto.MESSAGE, number=5, message=machine_resources.DiskSpec, ) @@ -318,9 +302,7 @@ class Scheduling(proto.Message): to workers leaving and joining a job. """ - timeout = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index 5c50d8e526..a12776f06c 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'DataItem', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"DataItem",}, ) @@ -73,19 +70,13 @@ class DataItem(proto.Message): name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index 0b123cc88e..d750f53e66 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -27,12 +27,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'DataLabelingJob', - 'ActiveLearningConfig', - 'SampleConfig', - 'TrainingConfig', + "DataLabelingJob", + "ActiveLearningConfig", + "SampleConfig", + "TrainingConfig", }, ) @@ -154,42 +154,30 @@ class DataLabelingJob(proto.Message): inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, - message=struct.Value, - ) + inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) - state = proto.Field(proto.ENUM, number=8, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, - message=money.Money, - ) + current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) - create_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=10, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=22, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) specialist_pools = proto.RepeatedField(proto.STRING, number=16) - encryption_spec = proto.Field(proto.MESSAGE, number=20, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=20, message=gca_encryption_spec.EncryptionSpec, ) - active_learning_config = proto.Field(proto.MESSAGE, number=21, - message='ActiveLearningConfig', + active_learning_config = proto.Field( + proto.MESSAGE, number=21, message="ActiveLearningConfig", ) @@ -218,18 +206,18 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field(proto.INT64, number=1, oneof='human_labeling_budget') - - max_data_item_percentage = proto.Field(proto.INT32, number=2, oneof='human_labeling_budget') - - sample_config = proto.Field(proto.MESSAGE, number=3, - message='SampleConfig', + max_data_item_count = proto.Field( + proto.INT64, number=1, oneof="human_labeling_budget" ) - training_config = proto.Field(proto.MESSAGE, number=4, - message='TrainingConfig', + max_data_item_percentage = proto.Field( + proto.INT32, number=2, oneof="human_labeling_budget" ) + sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + + training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + class SampleConfig(proto.Message): r"""Active learning data sampling config. For every active @@ -249,6 +237,7 @@ class SampleConfig(proto.Message): strategy will decide which data should be selected for human labeling in every batch. """ + class SampleStrategy(proto.Enum): r"""Sample strategy decides which subset of DataItems should be selected for human labeling in every batch. @@ -256,14 +245,16 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field(proto.INT32, number=1, oneof='initial_batch_sample_size') - - following_batch_sample_percentage = proto.Field(proto.INT32, number=3, oneof='following_batch_sample_size') + initial_batch_sample_percentage = proto.Field( + proto.INT32, number=1, oneof="initial_batch_sample_size" + ) - sample_strategy = proto.Field(proto.ENUM, number=5, - enum=SampleStrategy, + following_batch_sample_percentage = proto.Field( + proto.INT32, number=3, oneof="following_batch_sample_size" ) + sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + class TrainingConfig(proto.Message): r"""CMLE training config. For every active learning labeling diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 969596f706..9fa17fcb3a 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -25,12 +25,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Dataset', - 'ImportDataConfig', - 'ExportDataConfig', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, ) @@ -98,24 +94,18 @@ class Dataset(proto.Message): metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=6) labels = proto.MapField(proto.STRING, proto.STRING, number=7) - encryption_spec = proto.Field(proto.MESSAGE, number=11, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, ) @@ -151,8 +141,8 @@ class ImportDataConfig(proto.Message): Object `__. """ - gcs_source = proto.Field(proto.MESSAGE, number=1, oneof='source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, ) data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) @@ -185,8 +175,8 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, ) annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 73b9b56d5a..1ab94b8c89 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -26,26 +26,26 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateDatasetRequest', - 'CreateDatasetOperationMetadata', - 'GetDatasetRequest', - 'UpdateDatasetRequest', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'DeleteDatasetRequest', - 'ImportDataRequest', - 'ImportDataResponse', - 'ImportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportDataOperationMetadata', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'GetAnnotationSpecRequest', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", }, ) @@ -65,9 +65,7 @@ class CreateDatasetRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, - message=gca_dataset.Dataset, - ) + dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) class CreateDatasetOperationMetadata(proto.Message): @@ -79,8 +77,8 @@ class CreateDatasetOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -97,9 +95,7 @@ class GetDatasetRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateDatasetRequest(proto.Message): @@ -121,13 +117,9 @@ class UpdateDatasetRequest(proto.Message): - ``labels`` """ - dataset = proto.Field(proto.MESSAGE, number=1, - message=gca_dataset.Dataset, - ) + dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class ListDatasetsRequest(proto.Message): @@ -179,9 +171,7 @@ class ListDatasetsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -202,8 +192,8 @@ class ListDatasetsResponse(proto.Message): def raw_page(self): return self - datasets = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_dataset.Dataset, + datasets = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_dataset.Dataset, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -239,8 +229,8 @@ class ImportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField(proto.MESSAGE, number=2, - message=gca_dataset.ImportDataConfig, + import_configs = proto.RepeatedField( + proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, ) @@ -259,8 +249,8 @@ class ImportDataOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -278,8 +268,8 @@ class ExportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - export_config = proto.Field(proto.MESSAGE, number=2, - message=gca_dataset.ExportDataConfig, + export_config = proto.Field( + proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, ) @@ -309,8 +299,8 @@ class ExportDataOperationMetadata(proto.Message): the directory. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -347,9 +337,7 @@ class ListDataItemsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -370,8 +358,8 @@ class ListDataItemsResponse(proto.Message): def raw_page(self): return self - data_items = proto.RepeatedField(proto.MESSAGE, number=1, - message=data_item.DataItem, + data_items = proto.RepeatedField( + proto.MESSAGE, number=1, message=data_item.DataItem, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -391,9 +379,7 @@ class GetAnnotationSpecRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class ListAnnotationsRequest(proto.Message): @@ -427,9 +413,7 @@ class ListAnnotationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -450,8 +434,8 @@ class ListAnnotationsResponse(proto.Message): def raw_page(self): return self - annotations = proto.RepeatedField(proto.MESSAGE, number=1, - message=annotation.Annotation, + annotations = proto.RepeatedField( + proto.MESSAGE, number=1, message=annotation.Annotation, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py index aa5c8424aa..b0ec7010a2 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'DeployedModelRef', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"DeployedModelRef",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/encryption_spec.py b/google/cloud/aiplatform_v1beta1/types/encryption_spec.py index 398d935aa4..0d41d39a0b 100644 --- a/google/cloud/aiplatform_v1beta1/types/encryption_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/encryption_spec.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'EncryptionSpec', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"EncryptionSpec",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 85393de4b8..40ede068f3 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -25,11 +25,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Endpoint', - 'DeployedModel', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Endpoint", "DeployedModel",}, ) @@ -97,8 +93,8 @@ class Endpoint(proto.Message): description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField(proto.MESSAGE, number=4, - message='DeployedModel', + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=4, message="DeployedModel", ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) @@ -107,16 +103,12 @@ class Endpoint(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - encryption_spec = proto.Field(proto.MESSAGE, number=10, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, ) @@ -192,11 +184,17 @@ class DeployedModel(proto.Message): option. """ - dedicated_resources = proto.Field(proto.MESSAGE, number=7, oneof='prediction_resources', + dedicated_resources = proto.Field( + proto.MESSAGE, + number=7, + oneof="prediction_resources", message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field(proto.MESSAGE, number=8, oneof='prediction_resources', + automatic_resources = proto.Field( + proto.MESSAGE, + number=8, + oneof="prediction_resources", message=machine_resources.AutomaticResources, ) @@ -206,12 +204,10 @@ class DeployedModel(proto.Message): display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - explanation_spec = proto.Field(proto.MESSAGE, number=9, - message=explanation.ExplanationSpec, + explanation_spec = proto.Field( + proto.MESSAGE, number=9, message=explanation.ExplanationSpec, ) service_account = proto.Field(proto.STRING, number=11) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index 9fa5944c5f..fe7442ab2a 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateEndpointRequest', - 'CreateEndpointOperationMetadata', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UpdateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UndeployModelOperationMetadata', + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + "DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", }, ) @@ -58,9 +58,7 @@ class CreateEndpointRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, - message=gca_endpoint.Endpoint, - ) + endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) class CreateEndpointOperationMetadata(proto.Message): @@ -72,8 +70,8 @@ class CreateEndpointOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -143,9 +141,7 @@ class ListEndpointsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListEndpointsResponse(proto.Message): @@ -165,8 +161,8 @@ class ListEndpointsResponse(proto.Message): def raw_page(self): return self - endpoints = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_endpoint.Endpoint, + endpoints = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -185,13 +181,9 @@ class UpdateEndpointRequest(proto.Message): `FieldMask `__. """ - endpoint = proto.Field(proto.MESSAGE, number=1, - message=gca_endpoint.Endpoint, - ) + endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteEndpointRequest(proto.Message): @@ -244,8 +236,8 @@ class DeployModelRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field(proto.MESSAGE, number=2, - message=gca_endpoint.DeployedModel, + deployed_model = proto.Field( + proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -261,8 +253,8 @@ class DeployModelResponse(proto.Message): the Endpoint. """ - deployed_model = proto.Field(proto.MESSAGE, number=1, - message=gca_endpoint.DeployedModel, + deployed_model = proto.Field( + proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel, ) @@ -275,8 +267,8 @@ class DeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -325,8 +317,8 @@ class UndeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 1e1f279843..0d2c3769ff 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'EnvVar', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"EnvVar",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/event.py b/google/cloud/aiplatform_v1beta1/types/event.py index fedaf1e205..52bf55e074 100644 --- a/google/cloud/aiplatform_v1beta1/types/event.py +++ b/google/cloud/aiplatform_v1beta1/types/event.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Event', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Event",}, ) @@ -60,6 +57,7 @@ class Event(proto.Message): keys are prefixed with "aiplatform.googleapis.com/" and are immutable. """ + class Type(proto.Enum): r"""Describes whether an Event's Artifact is the Execution's input or output. @@ -72,13 +70,9 @@ class Type(proto.Enum): execution = proto.Field(proto.STRING, number=2) - event_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + event_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - type_ = proto.Field(proto.ENUM, number=4, - enum=Type, - ) + type_ = proto.Field(proto.ENUM, number=4, enum=Type,) labels = proto.MapField(proto.STRING, proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/execution.py b/google/cloud/aiplatform_v1beta1/types/execution.py index f252dc1def..d600e9a346 100644 --- a/google/cloud/aiplatform_v1beta1/types/execution.py +++ b/google/cloud/aiplatform_v1beta1/types/execution.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Execution', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Execution",}, ) @@ -103,6 +100,7 @@ class Execution(proto.Message): description (str): Description of the Execution """ + class State(proto.Enum): r"""Describes the state of the Execution.""" STATE_UNSPECIFIED = 0 @@ -115,29 +113,21 @@ class State(proto.Enum): display_name = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=6, - enum=State, - ) + state = proto.Field(proto.ENUM, number=6, enum=State,) etag = proto.Field(proto.STRING, number=9) labels = proto.MapField(proto.STRING, proto.STRING, number=10) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) schema_title = proto.Field(proto.STRING, number=13) schema_version = proto.Field(proto.STRING, number=14) - metadata = proto.Field(proto.MESSAGE, number=15, - message=struct.Struct, - ) + metadata = proto.Field(proto.MESSAGE, number=15, message=struct.Struct,) description = proto.Field(proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index e7980559cc..d9b48b60ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -23,20 +23,20 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'Explanation', - 'ModelExplanation', - 'Attribution', - 'ExplanationSpec', - 'ExplanationParameters', - 'SampledShapleyAttribution', - 'IntegratedGradientsAttribution', - 'XraiAttribution', - 'SmoothGradConfig', - 'FeatureNoiseSigma', - 'ExplanationSpecOverride', - 'ExplanationMetadataOverride', + "Explanation", + "ModelExplanation", + "Attribution", + "ExplanationSpec", + "ExplanationParameters", + "SampledShapleyAttribution", + "IntegratedGradientsAttribution", + "XraiAttribution", + "SmoothGradConfig", + "FeatureNoiseSigma", + "ExplanationSpecOverride", + "ExplanationMetadataOverride", }, ) @@ -73,9 +73,7 @@ class Explanation(proto.Message): in the same order as they appear in the output_indices. """ - attributions = proto.RepeatedField(proto.MESSAGE, number=1, - message='Attribution', - ) + attributions = proto.RepeatedField(proto.MESSAGE, number=1, message="Attribution",) class ModelExplanation(proto.Message): @@ -112,8 +110,8 @@ class ModelExplanation(proto.Message): is not populated. """ - mean_attributions = proto.RepeatedField(proto.MESSAGE, number=1, - message='Attribution', + mean_attributions = proto.RepeatedField( + proto.MESSAGE, number=1, message="Attribution", ) @@ -237,9 +235,7 @@ class Attribution(proto.Message): instance_output_value = proto.Field(proto.DOUBLE, number=2) - feature_attributions = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + feature_attributions = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) output_index = proto.RepeatedField(proto.INT32, number=4) @@ -262,12 +258,10 @@ class ExplanationSpec(proto.Message): input and output for explanation. """ - parameters = proto.Field(proto.MESSAGE, number=1, - message='ExplanationParameters', - ) + parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) - metadata = proto.Field(proto.MESSAGE, number=2, - message=explanation_metadata.ExplanationMetadata, + metadata = proto.Field( + proto.MESSAGE, number=2, message=explanation_metadata.ExplanationMetadata, ) @@ -324,23 +318,24 @@ class ExplanationParameters(proto.Message): (e,g, multi-class Models that predict multiple classes). """ - sampled_shapley_attribution = proto.Field(proto.MESSAGE, number=1, oneof='method', - message='SampledShapleyAttribution', + sampled_shapley_attribution = proto.Field( + proto.MESSAGE, number=1, oneof="method", message="SampledShapleyAttribution", ) - integrated_gradients_attribution = proto.Field(proto.MESSAGE, number=2, oneof='method', - message='IntegratedGradientsAttribution', + integrated_gradients_attribution = proto.Field( + proto.MESSAGE, + number=2, + oneof="method", + message="IntegratedGradientsAttribution", ) - xrai_attribution = proto.Field(proto.MESSAGE, number=3, oneof='method', - message='XraiAttribution', + xrai_attribution = proto.Field( + proto.MESSAGE, number=3, oneof="method", message="XraiAttribution", ) top_k = proto.Field(proto.INT32, number=4) - output_indices = proto.Field(proto.MESSAGE, number=5, - message=struct.ListValue, - ) + output_indices = proto.Field(proto.MESSAGE, number=5, message=struct.ListValue,) class SampledShapleyAttribution(proto.Message): @@ -387,8 +382,8 @@ class IntegratedGradientsAttribution(proto.Message): step_count = proto.Field(proto.INT32, number=1) - smooth_grad_config = proto.Field(proto.MESSAGE, number=2, - message='SmoothGradConfig', + smooth_grad_config = proto.Field( + proto.MESSAGE, number=2, message="SmoothGradConfig", ) @@ -421,8 +416,8 @@ class XraiAttribution(proto.Message): step_count = proto.Field(proto.INT32, number=1) - smooth_grad_config = proto.Field(proto.MESSAGE, number=2, - message='SmoothGradConfig', + smooth_grad_config = proto.Field( + proto.MESSAGE, number=2, message="SmoothGradConfig", ) @@ -467,10 +462,13 @@ class SmoothGradConfig(proto.Message): Valid range of its value is [1, 50]. Defaults to 3. """ - noise_sigma = proto.Field(proto.FLOAT, number=1, oneof='GradientNoiseSigma') + noise_sigma = proto.Field(proto.FLOAT, number=1, oneof="GradientNoiseSigma") - feature_noise_sigma = proto.Field(proto.MESSAGE, number=2, oneof='GradientNoiseSigma', - message='FeatureNoiseSigma', + feature_noise_sigma = proto.Field( + proto.MESSAGE, + number=2, + oneof="GradientNoiseSigma", + message="FeatureNoiseSigma", ) noisy_sample_count = proto.Field(proto.INT32, number=3) @@ -486,6 +484,7 @@ class FeatureNoiseSigma(proto.Message): Noise sigma per feature. No noise is added to features that are not set. """ + class NoiseSigmaForFeature(proto.Message): r"""Noise sigma for a single feature. @@ -507,8 +506,8 @@ class NoiseSigmaForFeature(proto.Message): sigma = proto.Field(proto.FLOAT, number=2) - noise_sigma = proto.RepeatedField(proto.MESSAGE, number=1, - message=NoiseSigmaForFeature, + noise_sigma = proto.RepeatedField( + proto.MESSAGE, number=1, message=NoiseSigmaForFeature, ) @@ -530,12 +529,10 @@ class ExplanationSpecOverride(proto.Message): specified, no metadata is overridden. """ - parameters = proto.Field(proto.MESSAGE, number=1, - message='ExplanationParameters', - ) + parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) - metadata = proto.Field(proto.MESSAGE, number=2, - message='ExplanationMetadataOverride', + metadata = proto.Field( + proto.MESSAGE, number=2, message="ExplanationMetadataOverride", ) @@ -556,6 +553,7 @@ class ExplanationMetadataOverride(proto.Message): here, the corresponding feature's input metadata is not overridden. """ + class InputMetadataOverride(proto.Message): r"""The [input metadata][google.cloud.aiplatform.v1beta1.ExplanationMetadata.InputMetadata] @@ -572,12 +570,12 @@ class InputMetadataOverride(proto.Message): overridden. """ - input_baselines = proto.RepeatedField(proto.MESSAGE, number=1, - message=struct.Value, + input_baselines = proto.RepeatedField( + proto.MESSAGE, number=1, message=struct.Value, ) - inputs = proto.MapField(proto.STRING, proto.MESSAGE, number=1, - message=InputMetadataOverride, + inputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message=InputMetadataOverride, ) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 79cb0925c4..69947e9b9e 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ExplanationMetadata', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"ExplanationMetadata",}, ) @@ -73,6 +70,7 @@ class ExplanationMetadata(proto.Message): output URI will point to a location where the user only has a read access. """ + class InputMetadata(proto.Message): r"""Metadata of the input of a feature. @@ -160,6 +158,7 @@ class InputMetadata(proto.Message): featureAttributions][Attribution.feature_attributions], keyed by the group name. """ + class Encoding(proto.Enum): r"""Defines how the feature is encoded to [encoded_tensor][]. Defaults to IDENTITY. @@ -251,6 +250,7 @@ class Visualization(proto.Message): makes it difficult to view the visualization. Defaults to NONE. """ + class Type(proto.Enum): r"""Type of the image visualization. Only applicable to [Integrated Gradients attribution] @@ -287,40 +287,50 @@ class OverlayType(proto.Enum): GRAYSCALE = 3 MASK_BLACK = 4 - type_ = proto.Field(proto.ENUM, number=1, - enum='ExplanationMetadata.InputMetadata.Visualization.Type', + type_ = proto.Field( + proto.ENUM, + number=1, + enum="ExplanationMetadata.InputMetadata.Visualization.Type", ) - polarity = proto.Field(proto.ENUM, number=2, - enum='ExplanationMetadata.InputMetadata.Visualization.Polarity', + polarity = proto.Field( + proto.ENUM, + number=2, + enum="ExplanationMetadata.InputMetadata.Visualization.Polarity", ) - color_map = proto.Field(proto.ENUM, number=3, - enum='ExplanationMetadata.InputMetadata.Visualization.ColorMap', + color_map = proto.Field( + proto.ENUM, + number=3, + enum="ExplanationMetadata.InputMetadata.Visualization.ColorMap", ) clip_percent_upperbound = proto.Field(proto.FLOAT, number=4) clip_percent_lowerbound = proto.Field(proto.FLOAT, number=5) - overlay_type = proto.Field(proto.ENUM, number=6, - enum='ExplanationMetadata.InputMetadata.Visualization.OverlayType', + overlay_type = proto.Field( + proto.ENUM, + number=6, + enum="ExplanationMetadata.InputMetadata.Visualization.OverlayType", ) - input_baselines = proto.RepeatedField(proto.MESSAGE, number=1, - message=struct.Value, + input_baselines = proto.RepeatedField( + proto.MESSAGE, number=1, message=struct.Value, ) input_tensor_name = proto.Field(proto.STRING, number=2) - encoding = proto.Field(proto.ENUM, number=3, - enum='ExplanationMetadata.InputMetadata.Encoding', + encoding = proto.Field( + proto.ENUM, number=3, enum="ExplanationMetadata.InputMetadata.Encoding", ) modality = proto.Field(proto.STRING, number=4) - feature_value_domain = proto.Field(proto.MESSAGE, number=5, - message='ExplanationMetadata.InputMetadata.FeatureValueDomain', + feature_value_domain = proto.Field( + proto.MESSAGE, + number=5, + message="ExplanationMetadata.InputMetadata.FeatureValueDomain", ) indices_tensor_name = proto.Field(proto.STRING, number=6) @@ -331,12 +341,14 @@ class OverlayType(proto.Enum): encoded_tensor_name = proto.Field(proto.STRING, number=9) - encoded_baselines = proto.RepeatedField(proto.MESSAGE, number=10, - message=struct.Value, + encoded_baselines = proto.RepeatedField( + proto.MESSAGE, number=10, message=struct.Value, ) - visualization = proto.Field(proto.MESSAGE, number=11, - message='ExplanationMetadata.InputMetadata.Visualization', + visualization = proto.Field( + proto.MESSAGE, + number=11, + message="ExplanationMetadata.InputMetadata.Visualization", ) group_name = proto.Field(proto.STRING, number=12) @@ -378,20 +390,22 @@ class OutputMetadata(proto.Message): for Tensorflow. """ - index_display_name_mapping = proto.Field(proto.MESSAGE, number=1, oneof='display_name_mapping', - message=struct.Value, + index_display_name_mapping = proto.Field( + proto.MESSAGE, number=1, oneof="display_name_mapping", message=struct.Value, ) - display_name_mapping_key = proto.Field(proto.STRING, number=2, oneof='display_name_mapping') + display_name_mapping_key = proto.Field( + proto.STRING, number=2, oneof="display_name_mapping" + ) output_tensor_name = proto.Field(proto.STRING, number=3) - inputs = proto.MapField(proto.STRING, proto.MESSAGE, number=1, - message=InputMetadata, + inputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, ) - outputs = proto.MapField(proto.STRING, proto.MESSAGE, number=2, - message=OutputMetadata, + outputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=2, message=OutputMetadata, ) feature_attributions_schema_uri = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py b/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py index 4dce0cda0e..a0c6d51e0f 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'FeatureStatsAnomaly', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"FeatureStatsAnomaly",}, ) @@ -96,13 +93,9 @@ class FeatureStatsAnomaly(proto.Message): anomaly_detection_threshold = proto.Field(proto.DOUBLE, number=9) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index fbf5262553..55978a409e 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -27,10 +27,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'HyperparameterTuningJob', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"HyperparameterTuningJob",}, ) @@ -109,9 +106,7 @@ class HyperparameterTuningJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, - message=study.StudySpec, - ) + study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) max_trial_count = proto.Field(proto.INT32, number=5) @@ -119,42 +114,28 @@ class HyperparameterTuningJob(proto.Message): max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field(proto.MESSAGE, number=8, - message=custom_job.CustomJobSpec, + trial_job_spec = proto.Field( + proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, - message=study.Trial, - ) + trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) - state = proto.Field(proto.ENUM, number=10, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=15, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=16) - encryption_spec = proto.Field(proto.MESSAGE, number=17, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=17, message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index 0d938b4628..3a177dcf9b 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -19,13 +19,13 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'GcsSource', - 'GcsDestination', - 'BigQuerySource', - 'BigQueryDestination', - 'ContainerRegistryDestination', + "GcsSource", + "GcsDestination", + "BigQuerySource", + "BigQueryDestination", + "ContainerRegistryDestination", }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index bc8b117832..7fd85c81b2 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -18,54 +18,62 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.cloud.aiplatform_v1beta1.types import operation from google.protobuf import field_mask_pb2 as field_mask # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateCustomJobRequest', - 'GetCustomJobRequest', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'DeleteCustomJobRequest', - 'CancelCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'GetDataLabelingJobRequest', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'DeleteDataLabelingJobRequest', - 'CancelDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'GetHyperparameterTuningJobRequest', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'DeleteHyperparameterTuningJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'GetBatchPredictionJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'DeleteBatchPredictionJobRequest', - 'CancelBatchPredictionJobRequest', - 'CreateModelDeploymentMonitoringJobRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', - 'GetModelDeploymentMonitoringJobRequest', - 'ListModelDeploymentMonitoringJobsRequest', - 'ListModelDeploymentMonitoringJobsResponse', - 'UpdateModelDeploymentMonitoringJobRequest', - 'DeleteModelDeploymentMonitoringJobRequest', - 'PauseModelDeploymentMonitoringJobRequest', - 'ResumeModelDeploymentMonitoringJobRequest', - 'UpdateModelDeploymentMonitoringJobOperationMetadata', + "CreateCustomJobRequest", + "GetCustomJobRequest", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "DeleteCustomJobRequest", + "CancelCustomJobRequest", + "CreateDataLabelingJobRequest", + "GetDataLabelingJobRequest", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "DeleteDataLabelingJobRequest", + "CancelDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "GetHyperparameterTuningJobRequest", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "DeleteHyperparameterTuningJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "GetBatchPredictionJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "DeleteBatchPredictionJobRequest", + "CancelBatchPredictionJobRequest", + "CreateModelDeploymentMonitoringJobRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesResponse", + "GetModelDeploymentMonitoringJobRequest", + "ListModelDeploymentMonitoringJobsRequest", + "ListModelDeploymentMonitoringJobsResponse", + "UpdateModelDeploymentMonitoringJobRequest", + "DeleteModelDeploymentMonitoringJobRequest", + "PauseModelDeploymentMonitoringJobRequest", + "ResumeModelDeploymentMonitoringJobRequest", + "UpdateModelDeploymentMonitoringJobOperationMetadata", }, ) @@ -85,9 +93,7 @@ class CreateCustomJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, - message=gca_custom_job.CustomJob, - ) + custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) class GetCustomJobRequest(proto.Message): @@ -150,9 +156,7 @@ class ListCustomJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListCustomJobsResponse(proto.Message): @@ -172,8 +176,8 @@ class ListCustomJobsResponse(proto.Message): def raw_page(self): return self - custom_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_custom_job.CustomJob, + custom_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -220,8 +224,8 @@ class CreateDataLabelingJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field(proto.MESSAGE, number=2, - message=gca_data_labeling_job.DataLabelingJob, + data_labeling_job = proto.Field( + proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, ) @@ -286,9 +290,7 @@ class ListDataLabelingJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -309,8 +311,8 @@ class ListDataLabelingJobsResponse(proto.Message): def raw_page(self): return self - data_labeling_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_data_labeling_job.DataLabelingJob, + data_labeling_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -359,7 +361,9 @@ class CreateHyperparameterTuningJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field(proto.MESSAGE, number=2, + hyperparameter_tuning_job = proto.Field( + proto.MESSAGE, + number=2, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -425,9 +429,7 @@ class ListHyperparameterTuningJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListHyperparameterTuningJobsResponse(proto.Message): @@ -449,7 +451,9 @@ class ListHyperparameterTuningJobsResponse(proto.Message): def raw_page(self): return self - hyperparameter_tuning_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + hyperparameter_tuning_jobs = proto.RepeatedField( + proto.MESSAGE, + number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -499,8 +503,8 @@ class CreateBatchPredictionJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field(proto.MESSAGE, number=2, - message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_job = proto.Field( + proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -567,9 +571,7 @@ class ListBatchPredictionJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListBatchPredictionJobsResponse(proto.Message): @@ -590,8 +592,8 @@ class ListBatchPredictionJobsResponse(proto.Message): def raw_page(self): return self - batch_prediction_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -640,7 +642,9 @@ class CreateModelDeploymentMonitoringJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - model_deployment_monitoring_job = proto.Field(proto.MESSAGE, number=2, + model_deployment_monitoring_job = proto.Field( + proto.MESSAGE, + number=2, message=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, ) @@ -680,6 +684,7 @@ class SearchModelDeploymentMonitoringStatsAnomaliesRequest(proto.Message): generated. If not set, indicates feching stats till the latest possible one. """ + class StatsAnomaliesObjective(proto.Message): r"""Stats requested for specific objective. @@ -697,7 +702,9 @@ class StatsAnomaliesObjective(proto.Message): latest monitoring run. """ - type_ = proto.Field(proto.ENUM, number=1, + type_ = proto.Field( + proto.ENUM, + number=1, enum=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringObjectiveType, ) @@ -709,21 +716,17 @@ class StatsAnomaliesObjective(proto.Message): feature_display_name = proto.Field(proto.STRING, number=3) - objectives = proto.RepeatedField(proto.MESSAGE, number=4, - message=StatsAnomaliesObjective, + objectives = proto.RepeatedField( + proto.MESSAGE, number=4, message=StatsAnomaliesObjective, ) page_size = proto.Field(proto.INT32, number=5) page_token = proto.Field(proto.STRING, number=6) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) class SearchModelDeploymentMonitoringStatsAnomaliesResponse(proto.Message): @@ -746,7 +749,9 @@ class SearchModelDeploymentMonitoringStatsAnomaliesResponse(proto.Message): def raw_page(self): return self - monitoring_stats = proto.RepeatedField(proto.MESSAGE, number=1, + monitoring_stats = proto.RepeatedField( + proto.MESSAGE, + number=1, message=gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies, ) @@ -793,9 +798,7 @@ class ListModelDeploymentMonitoringJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelDeploymentMonitoringJobsResponse(proto.Message): @@ -814,7 +817,9 @@ class ListModelDeploymentMonitoringJobsResponse(proto.Message): def raw_page(self): return self - model_deployment_monitoring_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + model_deployment_monitoring_jobs = proto.RepeatedField( + proto.MESSAGE, + number=1, message=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, ) @@ -834,13 +839,13 @@ class UpdateModelDeploymentMonitoringJobRequest(proto.Message): resource. """ - model_deployment_monitoring_job = proto.Field(proto.MESSAGE, number=1, + model_deployment_monitoring_job = proto.Field( + proto.MESSAGE, + number=1, message=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteModelDeploymentMonitoringJobRequest(proto.Message): @@ -894,8 +899,8 @@ class UpdateModelDeploymentMonitoringJobOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/job_state.py b/google/cloud/aiplatform_v1beta1/types/job_state.py index 6d199390db..b77947cc9a 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_state.py +++ b/google/cloud/aiplatform_v1beta1/types/job_state.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'JobState', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"JobState",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py b/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py index ba291eb8f6..f4ff6b2d97 100644 --- a/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py +++ b/google/cloud/aiplatform_v1beta1/types/lineage_subgraph.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'LineageSubgraph', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"LineageSubgraph",}, ) @@ -45,17 +42,13 @@ class LineageSubgraph(proto.Message): Executions in the subgraph. """ - artifacts = proto.RepeatedField(proto.MESSAGE, number=1, - message=artifact.Artifact, - ) + artifacts = proto.RepeatedField(proto.MESSAGE, number=1, message=artifact.Artifact,) - executions = proto.RepeatedField(proto.MESSAGE, number=2, - message=execution.Execution, + executions = proto.RepeatedField( + proto.MESSAGE, number=2, message=execution.Execution, ) - events = proto.RepeatedField(proto.MESSAGE, number=3, - message=event.Event, - ) + events = proto.RepeatedField(proto.MESSAGE, number=3, message=event.Event,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 48b2ad18c4..c791354c58 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -18,19 +18,21 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import accelerator_type as gca_accelerator_type +from google.cloud.aiplatform_v1beta1.types import ( + accelerator_type as gca_accelerator_type, +) __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'MachineSpec', - 'DedicatedResources', - 'AutomaticResources', - 'BatchDedicatedResources', - 'ResourcesConsumed', - 'DiskSpec', - 'AutoscalingMetricSpec', + "MachineSpec", + "DedicatedResources", + "AutomaticResources", + "BatchDedicatedResources", + "ResourcesConsumed", + "DiskSpec", + "AutoscalingMetricSpec", }, ) @@ -65,8 +67,8 @@ class MachineSpec(proto.Message): machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field(proto.ENUM, number=2, - enum=gca_accelerator_type.AcceleratorType, + accelerator_type = proto.Field( + proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, ) accelerator_count = proto.Field(proto.INT32, number=3) @@ -133,16 +135,14 @@ class DedicatedResources(proto.Message): to ``80``. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, - message='MachineSpec', - ) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) min_replica_count = proto.Field(proto.INT32, number=2) max_replica_count = proto.Field(proto.INT32, number=3) - autoscaling_metric_specs = proto.RepeatedField(proto.MESSAGE, number=4, - message='AutoscalingMetricSpec', + autoscaling_metric_specs = proto.RepeatedField( + proto.MESSAGE, number=4, message="AutoscalingMetricSpec", ) @@ -203,9 +203,7 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, - message='MachineSpec', - ) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) starting_replica_count = proto.Field(proto.INT32, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py index da5c4d38ab..7a467d5069 100644 --- a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py +++ b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ManualBatchTuningParameters', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"ManualBatchTuningParameters",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_schema.py b/google/cloud/aiplatform_v1beta1/types/metadata_schema.py index 7c690a1b94..d2c6f97fa8 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_schema.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_schema.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'MetadataSchema', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"MetadataSchema",}, ) @@ -60,6 +57,7 @@ class MetadataSchema(proto.Message): description (str): Description of the Metadata Schema """ + class MetadataSchemaType(proto.Enum): r"""Describes the type of the MetadataSchema.""" METADATA_SCHEMA_TYPE_UNSPECIFIED = 0 @@ -73,13 +71,9 @@ class MetadataSchemaType(proto.Enum): schema = proto.Field(proto.STRING, number=3) - schema_type = proto.Field(proto.ENUM, number=4, - enum=MetadataSchemaType, - ) + schema_type = proto.Field(proto.ENUM, number=4, enum=MetadataSchemaType,) - create_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) description = proto.Field(proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_service.py b/google/cloud/aiplatform_v1beta1/types/metadata_service.py index 3777316237..8dfe2682ee 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_service.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_service.py @@ -29,43 +29,43 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateMetadataStoreRequest', - 'CreateMetadataStoreOperationMetadata', - 'GetMetadataStoreRequest', - 'ListMetadataStoresRequest', - 'ListMetadataStoresResponse', - 'DeleteMetadataStoreRequest', - 'DeleteMetadataStoreOperationMetadata', - 'CreateArtifactRequest', - 'GetArtifactRequest', - 'ListArtifactsRequest', - 'ListArtifactsResponse', - 'UpdateArtifactRequest', - 'CreateContextRequest', - 'GetContextRequest', - 'ListContextsRequest', - 'ListContextsResponse', - 'UpdateContextRequest', - 'DeleteContextRequest', - 'AddContextArtifactsAndExecutionsRequest', - 'AddContextArtifactsAndExecutionsResponse', - 'AddContextChildrenRequest', - 'AddContextChildrenResponse', - 'QueryContextLineageSubgraphRequest', - 'CreateExecutionRequest', - 'GetExecutionRequest', - 'ListExecutionsRequest', - 'ListExecutionsResponse', - 'UpdateExecutionRequest', - 'AddExecutionEventsRequest', - 'AddExecutionEventsResponse', - 'QueryExecutionInputsAndOutputsRequest', - 'CreateMetadataSchemaRequest', - 'GetMetadataSchemaRequest', - 'ListMetadataSchemasRequest', - 'ListMetadataSchemasResponse', + "CreateMetadataStoreRequest", + "CreateMetadataStoreOperationMetadata", + "GetMetadataStoreRequest", + "ListMetadataStoresRequest", + "ListMetadataStoresResponse", + "DeleteMetadataStoreRequest", + "DeleteMetadataStoreOperationMetadata", + "CreateArtifactRequest", + "GetArtifactRequest", + "ListArtifactsRequest", + "ListArtifactsResponse", + "UpdateArtifactRequest", + "CreateContextRequest", + "GetContextRequest", + "ListContextsRequest", + "ListContextsResponse", + "UpdateContextRequest", + "DeleteContextRequest", + "AddContextArtifactsAndExecutionsRequest", + "AddContextArtifactsAndExecutionsResponse", + "AddContextChildrenRequest", + "AddContextChildrenResponse", + "QueryContextLineageSubgraphRequest", + "CreateExecutionRequest", + "GetExecutionRequest", + "ListExecutionsRequest", + "ListExecutionsResponse", + "UpdateExecutionRequest", + "AddExecutionEventsRequest", + "AddExecutionEventsResponse", + "QueryExecutionInputsAndOutputsRequest", + "CreateMetadataSchemaRequest", + "GetMetadataSchemaRequest", + "ListMetadataSchemasRequest", + "ListMetadataSchemasResponse", }, ) @@ -96,8 +96,8 @@ class CreateMetadataStoreRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - metadata_store = proto.Field(proto.MESSAGE, number=2, - message=gca_metadata_store.MetadataStore, + metadata_store = proto.Field( + proto.MESSAGE, number=2, message=gca_metadata_store.MetadataStore, ) metadata_store_id = proto.Field(proto.STRING, number=3) @@ -113,8 +113,8 @@ class CreateMetadataStoreOperationMetadata(proto.Message): MetadataStore. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -181,8 +181,8 @@ class ListMetadataStoresResponse(proto.Message): def raw_page(self): return self - metadata_stores = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_metadata_store.MetadataStore, + metadata_stores = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_metadata_store.MetadataStore, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -219,8 +219,8 @@ class DeleteMetadataStoreOperationMetadata(proto.Message): MetadataStore. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -249,9 +249,7 @@ class CreateArtifactRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - artifact = proto.Field(proto.MESSAGE, number=2, - message=gca_artifact.Artifact, - ) + artifact = proto.Field(proto.MESSAGE, number=2, message=gca_artifact.Artifact,) artifact_id = proto.Field(proto.STRING, number=3) @@ -324,8 +322,8 @@ class ListArtifactsResponse(proto.Message): def raw_page(self): return self - artifacts = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_artifact.Artifact, + artifacts = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_artifact.Artifact, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -353,13 +351,9 @@ class UpdateArtifactRequest(proto.Message): created. In this situation, ``update_mask`` is ignored. """ - artifact = proto.Field(proto.MESSAGE, number=1, - message=gca_artifact.Artifact, - ) + artifact = proto.Field(proto.MESSAGE, number=1, message=gca_artifact.Artifact,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) allow_missing = proto.Field(proto.BOOL, number=3) @@ -389,9 +383,7 @@ class CreateContextRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - context = proto.Field(proto.MESSAGE, number=2, - message=gca_context.Context, - ) + context = proto.Field(proto.MESSAGE, number=2, message=gca_context.Context,) context_id = proto.Field(proto.STRING, number=3) @@ -464,8 +456,8 @@ class ListContextsResponse(proto.Message): def raw_page(self): return self - contexts = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_context.Context, + contexts = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_context.Context, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -492,13 +484,9 @@ class UpdateContextRequest(proto.Message): created. In this situation, ``update_mask`` is ignored. """ - context = proto.Field(proto.MESSAGE, number=1, - message=gca_context.Context, - ) + context = proto.Field(proto.MESSAGE, number=1, message=gca_context.Context,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) allow_missing = proto.Field(proto.BOOL, number=3) @@ -624,9 +612,7 @@ class CreateExecutionRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - execution = proto.Field(proto.MESSAGE, number=2, - message=gca_execution.Execution, - ) + execution = proto.Field(proto.MESSAGE, number=2, message=gca_execution.Execution,) execution_id = proto.Field(proto.STRING, number=3) @@ -705,8 +691,8 @@ class ListExecutionsResponse(proto.Message): def raw_page(self): return self - executions = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_execution.Execution, + executions = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_execution.Execution, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -734,13 +720,9 @@ class UpdateExecutionRequest(proto.Message): be created. In this situation, ``update_mask`` is ignored. """ - execution = proto.Field(proto.MESSAGE, number=1, - message=gca_execution.Execution, - ) + execution = proto.Field(proto.MESSAGE, number=1, message=gca_execution.Execution,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) allow_missing = proto.Field(proto.BOOL, number=3) @@ -761,9 +743,7 @@ class AddExecutionEventsRequest(proto.Message): execution = proto.Field(proto.STRING, number=1) - events = proto.RepeatedField(proto.MESSAGE, number=2, - message=event.Event, - ) + events = proto.RepeatedField(proto.MESSAGE, number=2, message=event.Event,) class AddExecutionEventsResponse(proto.Message): @@ -814,8 +794,8 @@ class CreateMetadataSchemaRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - metadata_schema = proto.Field(proto.MESSAGE, number=2, - message=gca_metadata_schema.MetadataSchema, + metadata_schema = proto.Field( + proto.MESSAGE, number=2, message=gca_metadata_schema.MetadataSchema, ) metadata_schema_id = proto.Field(proto.STRING, number=3) @@ -890,8 +870,8 @@ class ListMetadataSchemasResponse(proto.Message): def raw_page(self): return self - metadata_schemas = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_metadata_schema.MetadataSchema, + metadata_schemas = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_metadata_schema.MetadataSchema, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_store.py b/google/cloud/aiplatform_v1beta1/types/metadata_store.py index da4704e31d..19456d92eb 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_store.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_store.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'MetadataStore', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"MetadataStore",}, ) @@ -53,16 +50,12 @@ class MetadataStore(proto.Message): name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - encryption_spec = proto.Field(proto.MESSAGE, number=5, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=5, message=gca_encryption_spec.EncryptionSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py index 07f9565af6..9a695ea349 100644 --- a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'MigratableResource', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"MigratableResource",}, ) @@ -55,6 +52,7 @@ class MigratableResource(proto.Message): Output only. Timestamp when this MigratableResource was last updated. """ + class MlEngineModelVersion(proto.Message): r"""Represents one model Version in ml.googleapis.com. @@ -123,6 +121,7 @@ class DataLabelingDataset(proto.Message): datalabeling.googleapis.com belongs to the data labeling Dataset. """ + class DataLabelingAnnotatedDataset(proto.Message): r"""Represents one AnnotatedDataset in datalabeling.googleapis.com. @@ -145,32 +144,34 @@ class DataLabelingAnnotatedDataset(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=4) - data_labeling_annotated_datasets = proto.RepeatedField(proto.MESSAGE, number=3, - message='MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset', + data_labeling_annotated_datasets = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", ) - ml_engine_model_version = proto.Field(proto.MESSAGE, number=1, oneof='resource', - message=MlEngineModelVersion, + ml_engine_model_version = proto.Field( + proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, ) - automl_model = proto.Field(proto.MESSAGE, number=2, oneof='resource', - message=AutomlModel, + automl_model = proto.Field( + proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, ) - automl_dataset = proto.Field(proto.MESSAGE, number=3, oneof='resource', - message=AutomlDataset, + automl_dataset = proto.Field( + proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, ) - data_labeling_dataset = proto.Field(proto.MESSAGE, number=4, oneof='resource', - message=DataLabelingDataset, + data_labeling_dataset = proto.Field( + proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, ) - last_migrate_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, + last_migrate_time = proto.Field( + proto.MESSAGE, number=5, message=timestamp.Timestamp, ) - last_update_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, + last_update_time = proto.Field( + proto.MESSAGE, number=6, message=timestamp.Timestamp, ) diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py index ec23daf2ff..de4c9466f6 100644 --- a/google/cloud/aiplatform_v1beta1/types/migration_service.py +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -18,21 +18,23 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import migratable_resource as gca_migratable_resource +from google.cloud.aiplatform_v1beta1.types import ( + migratable_resource as gca_migratable_resource, +) from google.cloud.aiplatform_v1beta1.types import operation from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'BatchMigrateResourcesRequest', - 'MigrateResourceRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceResponse', - 'BatchMigrateResourcesOperationMetadata', + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", }, ) @@ -99,8 +101,8 @@ class SearchMigratableResourcesResponse(proto.Message): def raw_page(self): return self - migratable_resources = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_migratable_resource.MigratableResource, + migratable_resources = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -123,8 +125,8 @@ class BatchMigrateResourcesRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - migrate_resource_requests = proto.RepeatedField(proto.MESSAGE, number=2, - message='MigrateResourceRequest', + migrate_resource_requests = proto.RepeatedField( + proto.MESSAGE, number=2, message="MigrateResourceRequest", ) @@ -148,6 +150,7 @@ class MigrateResourceRequest(proto.Message): datalabeling.googleapis.com to AI Platform's Dataset. """ + class MigrateMlEngineModelVersionConfig(proto.Message): r"""Config for migrating version in ml.googleapis.com to AI Platform's Model. @@ -235,6 +238,7 @@ class MigrateDataLabelingDatasetConfig(proto.Message): AnnotatedDatasets have to belong to the datalabeling Dataset. """ + class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): r"""Config for migrating AnnotatedDataset in datalabeling.googleapis.com to AI Platform's SavedQuery. @@ -252,23 +256,31 @@ class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=2) - migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField(proto.MESSAGE, number=3, - message='MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig', + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", ) - migrate_ml_engine_model_version_config = proto.Field(proto.MESSAGE, number=1, oneof='request', + migrate_ml_engine_model_version_config = proto.Field( + proto.MESSAGE, + number=1, + oneof="request", message=MigrateMlEngineModelVersionConfig, ) - migrate_automl_model_config = proto.Field(proto.MESSAGE, number=2, oneof='request', - message=MigrateAutomlModelConfig, + migrate_automl_model_config = proto.Field( + proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, ) - migrate_automl_dataset_config = proto.Field(proto.MESSAGE, number=3, oneof='request', - message=MigrateAutomlDatasetConfig, + migrate_automl_dataset_config = proto.Field( + proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, ) - migrate_data_labeling_dataset_config = proto.Field(proto.MESSAGE, number=4, oneof='request', + migrate_data_labeling_dataset_config = proto.Field( + proto.MESSAGE, + number=4, + oneof="request", message=MigrateDataLabelingDatasetConfig, ) @@ -282,8 +294,8 @@ class BatchMigrateResourcesResponse(proto.Message): Successfully migrated resources. """ - migrate_resource_responses = proto.RepeatedField(proto.MESSAGE, number=1, - message='MigrateResourceResponse', + migrate_resource_responses = proto.RepeatedField( + proto.MESSAGE, number=1, message="MigrateResourceResponse", ) @@ -301,12 +313,12 @@ class MigrateResourceResponse(proto.Message): datalabeling.googleapis.com. """ - dataset = proto.Field(proto.STRING, number=1, oneof='migrated_resource') + dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") - model = proto.Field(proto.STRING, number=2, oneof='migrated_resource') + model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") - migratable_resource = proto.Field(proto.MESSAGE, number=3, - message=gca_migratable_resource.MigratableResource, + migratable_resource = proto.Field( + proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, ) @@ -321,6 +333,7 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): Partial results that reflect the latest migration operation progress. """ + class PartialResult(proto.Message): r"""Represents a partial result in batch migration operation for one ``MigrateResourceRequest``. @@ -338,24 +351,24 @@ class PartialResult(proto.Message): [MigrateResourceRequest.migrate_resource_requests][]. """ - error = proto.Field(proto.MESSAGE, number=2, oneof='result', - message=status.Status, + error = proto.Field( + proto.MESSAGE, number=2, oneof="result", message=status.Status, ) - model = proto.Field(proto.STRING, number=3, oneof='result') + model = proto.Field(proto.STRING, number=3, oneof="result") - dataset = proto.Field(proto.STRING, number=4, oneof='result') + dataset = proto.Field(proto.STRING, number=4, oneof="result") - request = proto.Field(proto.MESSAGE, number=1, - message='MigrateResourceRequest', + request = proto.Field( + proto.MESSAGE, number=1, message="MigrateResourceRequest", ) - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - partial_results = proto.RepeatedField(proto.MESSAGE, number=2, - message=PartialResult, + partial_results = proto.RepeatedField( + proto.MESSAGE, number=2, message=PartialResult, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index aaa87f85bb..4dcf6baefa 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -27,13 +27,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Model', - 'PredictSchemata', - 'ModelContainerSpec', - 'Port', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, ) @@ -254,6 +249,7 @@ class Model(proto.Message): Model. If set, this Model and all sub-resources of this Model will be secured by this key. """ + class DeploymentResourcesType(proto.Enum): r"""Identifies a type of Model's prediction resources.""" DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 @@ -290,6 +286,7 @@ class ExportFormat(proto.Message): Output only. The content of this Model that may be exported. """ + class ExportableContent(proto.Enum): r"""The Model content that can be exported.""" EXPORTABLE_CONTENT_UNSPECIFIED = 0 @@ -298,8 +295,8 @@ class ExportableContent(proto.Enum): id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField(proto.ENUM, number=2, - enum='Model.ExportFormat.ExportableContent', + exportable_contents = proto.RepeatedField( + proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", ) name = proto.Field(proto.STRING, number=1) @@ -308,58 +305,48 @@ class ExportableContent(proto.Enum): description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, - message='PredictSchemata', - ) + predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - supported_export_formats = proto.RepeatedField(proto.MESSAGE, number=20, - message=ExportFormat, + supported_export_formats = proto.RepeatedField( + proto.MESSAGE, number=20, message=ExportFormat, ) training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, - message='ModelContainerSpec', - ) + container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField(proto.ENUM, number=10, - enum=DeploymentResourcesType, + supported_deployment_resources_types = proto.RepeatedField( + proto.ENUM, number=10, enum=DeploymentResourcesType, ) supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - deployed_models = proto.RepeatedField(proto.MESSAGE, number=15, - message=deployed_model_ref.DeployedModelRef, + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, ) - explanation_spec = proto.Field(proto.MESSAGE, number=23, - message=explanation.ExplanationSpec, + explanation_spec = proto.Field( + proto.MESSAGE, number=23, message=explanation.ExplanationSpec, ) etag = proto.Field(proto.STRING, number=16) labels = proto.MapField(proto.STRING, proto.STRING, number=17) - encryption_spec = proto.Field(proto.MESSAGE, number=24, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, ) @@ -667,13 +654,9 @@ class ModelContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, - message=env_var.EnvVar, - ) + env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) - ports = proto.RepeatedField(proto.MESSAGE, number=5, - message='Port', - ) + ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) predict_route = proto.Field(proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py index c500a28a8b..5a7015777b 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py +++ b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'ModelDeploymentMonitoringObjectiveType', - 'ModelDeploymentMonitoringJob', - 'ModelDeploymentMonitoringBigQueryTable', - 'ModelDeploymentMonitoringObjectiveConfig', - 'ModelDeploymentMonitoringScheduleConfig', - 'ModelMonitoringStatsAnomalies', + "ModelDeploymentMonitoringObjectiveType", + "ModelDeploymentMonitoringJob", + "ModelDeploymentMonitoringBigQueryTable", + "ModelDeploymentMonitoringObjectiveConfig", + "ModelDeploymentMonitoringScheduleConfig", + "ModelMonitoringStatsAnomalies", }, ) @@ -157,6 +157,7 @@ class ModelDeploymentMonitoringJob(proto.Message): stats_anomalies_base_directory (google.cloud.aiplatform_v1beta1.types.GcsDestination): Stats anomalies base folder path. """ + class MonitoringScheduleState(proto.Enum): r"""The state to Specify the monitoring pipeline.""" MONITORING_SCHEDULE_STATE_UNSPECIFIED = 0 @@ -170,62 +171,52 @@ class MonitoringScheduleState(proto.Enum): endpoint = proto.Field(proto.STRING, number=3) - state = proto.Field(proto.ENUM, number=4, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=4, enum=job_state.JobState,) - schedule_state = proto.Field(proto.ENUM, number=5, - enum=MonitoringScheduleState, - ) + schedule_state = proto.Field(proto.ENUM, number=5, enum=MonitoringScheduleState,) - model_deployment_monitoring_objective_configs = proto.RepeatedField(proto.MESSAGE, number=6, - message='ModelDeploymentMonitoringObjectiveConfig', + model_deployment_monitoring_objective_configs = proto.RepeatedField( + proto.MESSAGE, number=6, message="ModelDeploymentMonitoringObjectiveConfig", ) - model_deployment_monitoring_schedule_config = proto.Field(proto.MESSAGE, number=7, - message='ModelDeploymentMonitoringScheduleConfig', + model_deployment_monitoring_schedule_config = proto.Field( + proto.MESSAGE, number=7, message="ModelDeploymentMonitoringScheduleConfig", ) - logging_sampling_strategy = proto.Field(proto.MESSAGE, number=8, - message=model_monitoring.SamplingStrategy, + logging_sampling_strategy = proto.Field( + proto.MESSAGE, number=8, message=model_monitoring.SamplingStrategy, ) - model_monitoring_alert_config = proto.Field(proto.MESSAGE, number=15, - message=model_monitoring.ModelMonitoringAlertConfig, + model_monitoring_alert_config = proto.Field( + proto.MESSAGE, number=15, message=model_monitoring.ModelMonitoringAlertConfig, ) predict_instance_schema_uri = proto.Field(proto.STRING, number=9) - sample_predict_instance = proto.Field(proto.MESSAGE, number=19, - message=struct.Value, + sample_predict_instance = proto.Field( + proto.MESSAGE, number=19, message=struct.Value, ) analysis_instance_schema_uri = proto.Field(proto.STRING, number=16) - bigquery_tables = proto.RepeatedField(proto.MESSAGE, number=10, - message='ModelDeploymentMonitoringBigQueryTable', + bigquery_tables = proto.RepeatedField( + proto.MESSAGE, number=10, message="ModelDeploymentMonitoringBigQueryTable", ) - log_ttl = proto.Field(proto.MESSAGE, number=17, - message=duration.Duration, - ) + log_ttl = proto.Field(proto.MESSAGE, number=17, message=duration.Duration,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) - create_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - next_schedule_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, + next_schedule_time = proto.Field( + proto.MESSAGE, number=14, message=timestamp.Timestamp, ) - stats_anomalies_base_directory = proto.Field(proto.MESSAGE, number=20, - message=io.GcsDestination, + stats_anomalies_base_directory = proto.Field( + proto.MESSAGE, number=20, message=io.GcsDestination, ) @@ -244,6 +235,7 @@ class ModelDeploymentMonitoringBigQueryTable(proto.Message): their own query & analysis. Format: ``bq://.model_deployment_monitoring_._`` """ + class LogSource(proto.Enum): r"""Indicates where does the log come from.""" LOG_SOURCE_UNSPECIFIED = 0 @@ -256,13 +248,9 @@ class LogType(proto.Enum): PREDICT = 1 EXPLAIN = 2 - log_source = proto.Field(proto.ENUM, number=1, - enum=LogSource, - ) + log_source = proto.Field(proto.ENUM, number=1, enum=LogSource,) - log_type = proto.Field(proto.ENUM, number=2, - enum=LogType, - ) + log_type = proto.Field(proto.ENUM, number=2, enum=LogType,) bigquery_table_path = proto.Field(proto.STRING, number=3) @@ -281,7 +269,9 @@ class ModelDeploymentMonitoringObjectiveConfig(proto.Message): deployed_model_id = proto.Field(proto.STRING, number=1) - objective_config = proto.Field(proto.MESSAGE, number=2, + objective_config = proto.Field( + proto.MESSAGE, + number=2, message=model_monitoring.ModelMonitoringObjectiveConfig, ) @@ -296,9 +286,7 @@ class ModelDeploymentMonitoringScheduleConfig(proto.Message): hour. """ - monitor_interval = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + monitor_interval = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) class ModelMonitoringStatsAnomalies(proto.Message): @@ -316,6 +304,7 @@ class ModelMonitoringStatsAnomalies(proto.Message): A list of historical Stats and Anomalies generated for all Features. """ + class FeatureHistoricStatsAnomalies(proto.Message): r"""Historical Stats (and Anomalies) for a specific Feature. @@ -333,28 +322,32 @@ class FeatureHistoricStatsAnomalies(proto.Message): feature_display_name = proto.Field(proto.STRING, number=1) - threshold = proto.Field(proto.MESSAGE, number=3, - message=model_monitoring.ThresholdConfig, + threshold = proto.Field( + proto.MESSAGE, number=3, message=model_monitoring.ThresholdConfig, ) - training_stats = proto.Field(proto.MESSAGE, number=4, + training_stats = proto.Field( + proto.MESSAGE, + number=4, message=feature_monitoring_stats.FeatureStatsAnomaly, ) - prediction_stats = proto.RepeatedField(proto.MESSAGE, number=5, + prediction_stats = proto.RepeatedField( + proto.MESSAGE, + number=5, message=feature_monitoring_stats.FeatureStatsAnomaly, ) - objective = proto.Field(proto.ENUM, number=1, - enum='ModelDeploymentMonitoringObjectiveType', + objective = proto.Field( + proto.ENUM, number=1, enum="ModelDeploymentMonitoringObjectiveType", ) deployed_model_id = proto.Field(proto.STRING, number=2) anomaly_count = proto.Field(proto.INT32, number=3) - feature_stats = proto.RepeatedField(proto.MESSAGE, number=4, - message=FeatureHistoricStatsAnomalies, + feature_stats = proto.RepeatedField( + proto.MESSAGE, number=4, message=FeatureHistoricStatsAnomalies, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index d0a4a5a146..391bc38cf4 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ModelEvaluation', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluation",}, ) @@ -74,6 +71,7 @@ class ModelEvaluation(proto.Message): that are used for explaining the predicted values on the evaluated data. """ + class ModelEvaluationExplanationSpec(proto.Message): r""" @@ -91,30 +89,26 @@ class ModelEvaluationExplanationSpec(proto.Message): explanation_type = proto.Field(proto.STRING, number=1) - explanation_spec = proto.Field(proto.MESSAGE, number=2, - message=explanation.ExplanationSpec, + explanation_spec = proto.Field( + proto.MESSAGE, number=2, message=explanation.ExplanationSpec, ) name = proto.Field(proto.STRING, number=1) metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) slice_dimensions = proto.RepeatedField(proto.STRING, number=5) - model_explanation = proto.Field(proto.MESSAGE, number=8, - message=explanation.ModelExplanation, + model_explanation = proto.Field( + proto.MESSAGE, number=8, message=explanation.ModelExplanation, ) - explanation_specs = proto.RepeatedField(proto.MESSAGE, number=9, - message=ModelEvaluationExplanationSpec, + explanation_specs = proto.RepeatedField( + proto.MESSAGE, number=9, message=ModelEvaluationExplanationSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index 3895dd1170..2d66e29a9f 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ModelEvaluationSlice', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluationSlice",}, ) @@ -57,6 +54,7 @@ class ModelEvaluationSlice(proto.Message): Output only. Timestamp when this ModelEvaluationSlice was created. """ + class Slice(proto.Message): r"""Definition of a slice. @@ -81,19 +79,13 @@ class Slice(proto.Message): name = proto.Field(proto.STRING, number=1) - slice_ = proto.Field(proto.MESSAGE, number=2, - message=Slice, - ) + slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py index f57417be64..fd605d8265 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py +++ b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py @@ -22,12 +22,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'ModelMonitoringObjectiveConfig', - 'ModelMonitoringAlertConfig', - 'ThresholdConfig', - 'SamplingStrategy', + "ModelMonitoringObjectiveConfig", + "ModelMonitoringAlertConfig", + "ThresholdConfig", + "SamplingStrategy", }, ) @@ -47,6 +47,7 @@ class ModelMonitoringObjectiveConfig(proto.Message): prediction_drift_detection_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig): The config for drift of prediction data. """ + class TrainingDataset(proto.Message): r"""Training Dataset information. @@ -80,22 +81,22 @@ class TrainingDataset(proto.Message): dataset. """ - dataset = proto.Field(proto.STRING, number=3, oneof='data_source') + dataset = proto.Field(proto.STRING, number=3, oneof="data_source") - gcs_source = proto.Field(proto.MESSAGE, number=4, oneof='data_source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=4, oneof="data_source", message=io.GcsSource, ) - bigquery_source = proto.Field(proto.MESSAGE, number=5, oneof='data_source', - message=io.BigQuerySource, + bigquery_source = proto.Field( + proto.MESSAGE, number=5, oneof="data_source", message=io.BigQuerySource, ) data_format = proto.Field(proto.STRING, number=2) target_field = proto.Field(proto.STRING, number=6) - logging_sampling_strategy = proto.Field(proto.MESSAGE, number=7, - message='SamplingStrategy', + logging_sampling_strategy = proto.Field( + proto.MESSAGE, number=7, message="SamplingStrategy", ) class TrainingPredictionSkewDetectionConfig(proto.Message): @@ -113,8 +114,8 @@ class TrainingPredictionSkewDetectionConfig(proto.Message): training and prediction feature. """ - skew_thresholds = proto.MapField(proto.STRING, proto.MESSAGE, number=1, - message='ThresholdConfig', + skew_thresholds = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message="ThresholdConfig", ) class PredictionDriftDetectionConfig(proto.Message): @@ -130,20 +131,18 @@ class PredictionDriftDetectionConfig(proto.Message): time windws. """ - drift_thresholds = proto.MapField(proto.STRING, proto.MESSAGE, number=1, - message='ThresholdConfig', + drift_thresholds = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message="ThresholdConfig", ) - training_dataset = proto.Field(proto.MESSAGE, number=1, - message=TrainingDataset, - ) + training_dataset = proto.Field(proto.MESSAGE, number=1, message=TrainingDataset,) - training_prediction_skew_detection_config = proto.Field(proto.MESSAGE, number=2, - message=TrainingPredictionSkewDetectionConfig, + training_prediction_skew_detection_config = proto.Field( + proto.MESSAGE, number=2, message=TrainingPredictionSkewDetectionConfig, ) - prediction_drift_detection_config = proto.Field(proto.MESSAGE, number=3, - message=PredictionDriftDetectionConfig, + prediction_drift_detection_config = proto.Field( + proto.MESSAGE, number=3, message=PredictionDriftDetectionConfig, ) @@ -154,6 +153,7 @@ class ModelMonitoringAlertConfig(proto.Message): email_alert_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringAlertConfig.EmailAlertConfig): Email alert config. """ + class EmailAlertConfig(proto.Message): r"""The config for email alert. @@ -164,8 +164,8 @@ class EmailAlertConfig(proto.Message): user_emails = proto.RepeatedField(proto.STRING, number=1) - email_alert_config = proto.Field(proto.MESSAGE, number=1, oneof='alert', - message=EmailAlertConfig, + email_alert_config = proto.Field( + proto.MESSAGE, number=1, oneof="alert", message=EmailAlertConfig, ) @@ -188,7 +188,7 @@ class ThresholdConfig(proto.Message): will be triggered for that feature. """ - value = proto.Field(proto.DOUBLE, number=1, oneof='threshold') + value = proto.Field(proto.DOUBLE, number=1, oneof="threshold") class SamplingStrategy(proto.Message): @@ -201,6 +201,7 @@ class SamplingStrategy(proto.Message): Random sample config. Will support more sampling strategies later. """ + class RandomSampleConfig(proto.Message): r"""Requests are randomly selected. @@ -211,8 +212,8 @@ class RandomSampleConfig(proto.Message): sample_rate = proto.Field(proto.DOUBLE, number=1) - random_sample_config = proto.Field(proto.MESSAGE, number=1, - message=RandomSampleConfig, + random_sample_config = proto.Field( + proto.MESSAGE, number=1, message=RandomSampleConfig, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 46b5328166..e0d8e148ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -27,25 +27,25 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'UploadModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelResponse', - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'DeleteModelRequest', - 'ExportModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'GetModelEvaluationSliceRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', + "UploadModelRequest", + "UploadModelOperationMetadata", + "UploadModelResponse", + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "DeleteModelRequest", + "ExportModelRequest", + "ExportModelOperationMetadata", + "ExportModelResponse", + "GetModelEvaluationRequest", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "GetModelEvaluationSliceRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", }, ) @@ -65,9 +65,7 @@ class UploadModelRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, - message=gca_model.Model, - ) + model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) class UploadModelOperationMetadata(proto.Message): @@ -80,8 +78,8 @@ class UploadModelOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -161,9 +159,7 @@ class ListModelsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelsResponse(proto.Message): @@ -183,9 +179,7 @@ class ListModelsResponse(proto.Message): def raw_page(self): return self - models = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_model.Model, - ) + models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) next_page_token = proto.Field(proto.STRING, number=2) @@ -204,13 +198,9 @@ class UpdateModelRequest(proto.Message): `FieldMask `__. """ - model = proto.Field(proto.MESSAGE, number=1, - message=gca_model.Model, - ) + model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteModelRequest(proto.Message): @@ -239,6 +229,7 @@ class ExportModelRequest(proto.Message): Required. The desired output location and configuration. """ + class OutputConfig(proto.Message): r"""Output configuration for the Model export. @@ -270,19 +261,17 @@ class OutputConfig(proto.Message): export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field(proto.MESSAGE, number=3, - message=io.GcsDestination, + artifact_destination = proto.Field( + proto.MESSAGE, number=3, message=io.GcsDestination, ) - image_destination = proto.Field(proto.MESSAGE, number=4, - message=io.ContainerRegistryDestination, + image_destination = proto.Field( + proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, - message=OutputConfig, - ) + output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) class ExportModelOperationMetadata(proto.Message): @@ -297,6 +286,7 @@ class ExportModelOperationMetadata(proto.Message): Output only. Information further describing the output of this Model export. """ + class OutputInfo(proto.Message): r"""Further describes the output of the ExportModel. Supplements ``ExportModelRequest.OutputConfig``. @@ -318,13 +308,11 @@ class OutputInfo(proto.Message): image_output_uri = proto.Field(proto.STRING, number=3) - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - output_info = proto.Field(proto.MESSAGE, number=2, - message=OutputInfo, - ) + output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) class ExportModelResponse(proto.Message): @@ -378,9 +366,7 @@ class ListModelEvaluationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelEvaluationsResponse(proto.Message): @@ -401,8 +387,8 @@ class ListModelEvaluationsResponse(proto.Message): def raw_page(self): return self - model_evaluations = proto.RepeatedField(proto.MESSAGE, number=1, - message=model_evaluation.ModelEvaluation, + model_evaluations = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -455,9 +441,7 @@ class ListModelEvaluationSlicesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelEvaluationSlicesResponse(proto.Message): @@ -478,8 +462,8 @@ class ListModelEvaluationSlicesResponse(proto.Message): def raw_page(self): return self - model_evaluation_slices = proto.RepeatedField(proto.MESSAGE, number=1, - message=model_evaluation_slice.ModelEvaluationSlice, + model_evaluation_slices = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 887e903ff2..90565867e8 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -23,11 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'GenericOperationMetadata', - 'DeleteOperationMetadata', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, ) @@ -51,17 +48,13 @@ class GenericOperationMetadata(proto.Message): finish time. """ - partial_failures = proto.RepeatedField(proto.MESSAGE, number=1, - message=status.Status, + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=1, message=status.Status, ) - create_time = proto.Field(proto.MESSAGE, number=2, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) class DeleteOperationMetadata(proto.Message): @@ -72,8 +65,8 @@ class DeleteOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message='GenericOperationMetadata', + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message="GenericOperationMetadata", ) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index a5add3f9ca..b06361dfa9 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -18,19 +18,21 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'DeleteTrainingPipelineRequest', - 'CancelTrainingPipelineRequest', + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", }, ) @@ -50,8 +52,8 @@ class CreateTrainingPipelineRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field(proto.MESSAGE, number=2, - message=gca_training_pipeline.TrainingPipeline, + training_pipeline = proto.Field( + proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, ) @@ -113,9 +115,7 @@ class ListTrainingPipelinesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListTrainingPipelinesResponse(proto.Message): @@ -136,8 +136,8 @@ class ListTrainingPipelinesResponse(proto.Message): def raw_page(self): return self - training_pipelines = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_training_pipeline.TrainingPipeline, + training_pipelines = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py index b04954f602..cede653bd6 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'PipelineState', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"PipelineState",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index 24011ca24d..f7abe9e3e2 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -23,12 +23,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'PredictRequest', - 'PredictResponse', - 'ExplainRequest', - 'ExplainResponse', + "PredictRequest", + "PredictResponse", + "ExplainRequest", + "ExplainResponse", }, ) @@ -65,13 +65,9 @@ class PredictRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, - message=struct.Value, - ) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) class PredictResponse(proto.Message): @@ -91,9 +87,7 @@ class PredictResponse(proto.Message): served this prediction. """ - predictions = proto.RepeatedField(proto.MESSAGE, number=1, - message=struct.Value, - ) + predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) deployed_model_id = proto.Field(proto.STRING, number=2) @@ -145,16 +139,12 @@ class ExplainRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, - message=struct.Value, - ) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - explanation_spec_override = proto.Field(proto.MESSAGE, number=5, - message=explanation.ExplanationSpecOverride, + explanation_spec_override = proto.Field( + proto.MESSAGE, number=5, message=explanation.ExplanationSpecOverride, ) deployed_model_id = proto.Field(proto.STRING, number=3) @@ -181,15 +171,13 @@ class ExplainResponse(proto.Message): ``PredictResponse.predictions``. """ - explanations = proto.RepeatedField(proto.MESSAGE, number=1, - message=explanation.Explanation, + explanations = proto.RepeatedField( + proto.MESSAGE, number=1, message=explanation.Explanation, ) deployed_model_id = proto.Field(proto.STRING, number=2) - predictions = proto.RepeatedField(proto.MESSAGE, number=3, - message=struct.Value, - ) + predictions = proto.RepeatedField(proto.MESSAGE, number=3, message=struct.Value,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py index f75416157b..4ac8c6a709 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'SpecialistPool', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"SpecialistPool",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index aa9e9235ef..3ed6593bd6 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -24,16 +24,16 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateSpecialistPoolRequest', - 'CreateSpecialistPoolOperationMetadata', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'DeleteSpecialistPoolRequest', - 'UpdateSpecialistPoolRequest', - 'UpdateSpecialistPoolOperationMetadata', + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", }, ) @@ -53,8 +53,8 @@ class CreateSpecialistPoolRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field(proto.MESSAGE, number=2, - message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field( + proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, ) @@ -67,8 +67,8 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -113,9 +113,7 @@ class ListSpecialistPoolsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) class ListSpecialistPoolsResponse(proto.Message): @@ -134,8 +132,8 @@ class ListSpecialistPoolsResponse(proto.Message): def raw_page(self): return self - specialist_pools = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_specialist_pool.SpecialistPool, + specialist_pools = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -175,13 +173,11 @@ class UpdateSpecialistPoolRequest(proto.Message): resource. """ - specialist_pool = proto.Field(proto.MESSAGE, number=1, - message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateSpecialistPoolOperationMetadata(proto.Message): @@ -199,8 +195,8 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field(proto.MESSAGE, number=2, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index f1f28d8669..49831d7718 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -24,13 +24,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Study', - 'Trial', - 'StudySpec', - 'Measurement', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"Study", "Trial", "StudySpec", "Measurement",}, ) @@ -57,6 +52,7 @@ class Study(proto.Message): Study is inactive. This should be empty if a study is ACTIVE or COMPLETED. """ + class State(proto.Enum): r"""Describes the Study state.""" STATE_UNSPECIFIED = 0 @@ -68,17 +64,11 @@ class State(proto.Enum): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=3, - message='StudySpec', - ) + study_spec = proto.Field(proto.MESSAGE, number=3, message="StudySpec",) - state = proto.Field(proto.ENUM, number=4, - enum=State, - ) + state = proto.Field(proto.ENUM, number=4, enum=State,) - create_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) inactive_reason = proto.Field(proto.STRING, number=6) @@ -131,6 +121,7 @@ class Trial(proto.Message): Trial. It's set for a HyperparameterTuningJob's Trial. """ + class State(proto.Enum): r"""Describes a Trial state.""" STATE_UNSPECIFIED = 0 @@ -158,37 +149,23 @@ class Parameter(proto.Message): parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, - message=struct.Value, - ) + value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) name = proto.Field(proto.STRING, number=1) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, - enum=State, - ) + state = proto.Field(proto.ENUM, number=3, enum=State,) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, - message=Parameter, - ) + parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) - final_measurement = proto.Field(proto.MESSAGE, number=5, - message='Measurement', - ) + final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) - measurements = proto.RepeatedField(proto.MESSAGE, number=6, - message='Measurement', - ) + measurements = proto.RepeatedField(proto.MESSAGE, number=6, message="Measurement",) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) client_id = proto.Field(proto.STRING, number=9) @@ -225,6 +202,7 @@ class StudySpec(proto.Message): Describe which measurement selection type will be used """ + class Algorithm(proto.Enum): r"""The available search algorithms for the Study.""" ALGORITHM_UNSPECIFIED = 0 @@ -270,6 +248,7 @@ class MetricSpec(proto.Message): Required. The optimization goal of the metric. """ + class GoalType(proto.Enum): r"""The available types of optimization goals.""" GOAL_TYPE_UNSPECIFIED = 0 @@ -278,9 +257,7 @@ class GoalType(proto.Enum): metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, - enum='StudySpec.MetricSpec.GoalType', - ) + goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. @@ -308,6 +285,7 @@ class ParameterSpec(proto.Message): If two items in conditional_parameter_specs have the same name, they must have disjoint parent_value_condition. """ + class ScaleType(proto.Enum): r"""The type of scaling that should be applied to this parameter.""" SCALE_TYPE_UNSPECIFIED = 0 @@ -390,6 +368,7 @@ class ConditionalParameterSpec(proto.Message): Required. The spec for a conditional parameter. """ + class DiscreteValueCondition(proto.Message): r"""Represents the spec to match discrete values from parent parameter. @@ -431,46 +410,69 @@ class CategoricalValueCondition(proto.Message): values = proto.RepeatedField(proto.STRING, number=1) - parent_discrete_values = proto.Field(proto.MESSAGE, number=2, oneof='parent_value_condition', - message='StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition', + parent_discrete_values = proto.Field( + proto.MESSAGE, + number=2, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", ) - parent_int_values = proto.Field(proto.MESSAGE, number=3, oneof='parent_value_condition', - message='StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition', + parent_int_values = proto.Field( + proto.MESSAGE, + number=3, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", ) - parent_categorical_values = proto.Field(proto.MESSAGE, number=4, oneof='parent_value_condition', - message='StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition', + parent_categorical_values = proto.Field( + proto.MESSAGE, + number=4, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", ) - parameter_spec = proto.Field(proto.MESSAGE, number=1, - message='StudySpec.ParameterSpec', + parameter_spec = proto.Field( + proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", ) - double_value_spec = proto.Field(proto.MESSAGE, number=2, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.DoubleValueSpec', + double_value_spec = proto.Field( + proto.MESSAGE, + number=2, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DoubleValueSpec", ) - integer_value_spec = proto.Field(proto.MESSAGE, number=3, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.IntegerValueSpec', + integer_value_spec = proto.Field( + proto.MESSAGE, + number=3, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.IntegerValueSpec", ) - categorical_value_spec = proto.Field(proto.MESSAGE, number=4, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.CategoricalValueSpec', + categorical_value_spec = proto.Field( + proto.MESSAGE, + number=4, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.CategoricalValueSpec", ) - discrete_value_spec = proto.Field(proto.MESSAGE, number=5, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.DiscreteValueSpec', + discrete_value_spec = proto.Field( + proto.MESSAGE, + number=5, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DiscreteValueSpec", ) parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field(proto.ENUM, number=6, - enum='StudySpec.ParameterSpec.ScaleType', + scale_type = proto.Field( + proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", ) - conditional_parameter_specs = proto.RepeatedField(proto.MESSAGE, number=10, - message='StudySpec.ParameterSpec.ConditionalParameterSpec', + conditional_parameter_specs = proto.RepeatedField( + proto.MESSAGE, + number=10, + message="StudySpec.ParameterSpec.ConditionalParameterSpec", ) class DecayCurveAutomatedStoppingSpec(proto.Message): @@ -559,36 +561,37 @@ class ConvexStopConfig(proto.Message): use_seconds = proto.Field(proto.BOOL, number=5) - decay_curve_stopping_spec = proto.Field(proto.MESSAGE, number=4, oneof='automated_stopping_spec', + decay_curve_stopping_spec = proto.Field( + proto.MESSAGE, + number=4, + oneof="automated_stopping_spec", message=DecayCurveAutomatedStoppingSpec, ) - median_automated_stopping_spec = proto.Field(proto.MESSAGE, number=5, oneof='automated_stopping_spec', + median_automated_stopping_spec = proto.Field( + proto.MESSAGE, + number=5, + oneof="automated_stopping_spec", message=MedianAutomatedStoppingSpec, ) - convex_stop_config = proto.Field(proto.MESSAGE, number=8, oneof='automated_stopping_spec', + convex_stop_config = proto.Field( + proto.MESSAGE, + number=8, + oneof="automated_stopping_spec", message=ConvexStopConfig, ) - metrics = proto.RepeatedField(proto.MESSAGE, number=1, - message=MetricSpec, - ) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, - message=ParameterSpec, - ) + parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) - algorithm = proto.Field(proto.ENUM, number=3, - enum=Algorithm, - ) + algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) - observation_noise = proto.Field(proto.ENUM, number=6, - enum=ObservationNoise, - ) + observation_noise = proto.Field(proto.ENUM, number=6, enum=ObservationNoise,) - measurement_selection_type = proto.Field(proto.ENUM, number=7, - enum=MeasurementSelectionType, + measurement_selection_type = proto.Field( + proto.ENUM, number=7, enum=MeasurementSelectionType, ) @@ -610,6 +613,7 @@ class Measurement(proto.Message): evaluating the objective functions using suggested Parameter values. """ + class Metric(proto.Message): r"""A message representing a metric in the measurement. @@ -626,15 +630,11 @@ class Metric(proto.Message): value = proto.Field(proto.DOUBLE, number=2) - elapsed_duration = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + elapsed_duration = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, - message=Metric, - ) + metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 84f1a7d2c6..3c03b0f47d 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -28,14 +28,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'TrainingPipeline', - 'InputDataConfig', - 'FractionSplit', - 'FilterSplit', - 'PredefinedSplit', - 'TimestampSplit', + "TrainingPipeline", + "InputDataConfig", + "FractionSplit", + "FilterSplit", + "PredefinedSplit", + "TimestampSplit", }, ) @@ -155,52 +155,32 @@ class TrainingPipeline(proto.Message): display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, - message='InputDataConfig', - ) + input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, - message=struct.Value, - ) + training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - model_to_upload = proto.Field(proto.MESSAGE, number=7, - message=model.Model, - ) + model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) - state = proto.Field(proto.ENUM, number=9, - enum=pipeline_state.PipelineState, - ) + state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) - error = proto.Field(proto.MESSAGE, number=10, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=15) - encryption_spec = proto.Field(proto.MESSAGE, number=18, - message=gca_encryption_spec.EncryptionSpec, + encryption_spec = proto.Field( + proto.MESSAGE, number=18, message=gca_encryption_spec.EncryptionSpec, ) @@ -321,28 +301,28 @@ class InputDataConfig(proto.Message): ``annotation_schema_uri``. """ - fraction_split = proto.Field(proto.MESSAGE, number=2, oneof='split', - message='FractionSplit', + fraction_split = proto.Field( + proto.MESSAGE, number=2, oneof="split", message="FractionSplit", ) - filter_split = proto.Field(proto.MESSAGE, number=3, oneof='split', - message='FilterSplit', + filter_split = proto.Field( + proto.MESSAGE, number=3, oneof="split", message="FilterSplit", ) - predefined_split = proto.Field(proto.MESSAGE, number=4, oneof='split', - message='PredefinedSplit', + predefined_split = proto.Field( + proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", ) - timestamp_split = proto.Field(proto.MESSAGE, number=5, oneof='split', - message='TimestampSplit', + timestamp_split = proto.Field( + proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", ) - gcs_destination = proto.Field(proto.MESSAGE, number=8, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, ) - bigquery_destination = proto.Field(proto.MESSAGE, number=10, oneof='destination', - message=io.BigQueryDestination, + bigquery_destination = proto.Field( + proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, ) dataset_id = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index a2ff3629c0..9b5532b2e0 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'UserActionReference', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"UserActionReference",}, ) @@ -47,9 +44,9 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.master.DatasetService.CreateDataset". """ - operation = proto.Field(proto.STRING, number=1, oneof='reference') + operation = proto.Field(proto.STRING, number=1, oneof="reference") - data_labeling_job = proto.Field(proto.STRING, number=2, oneof='reference') + data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") method = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/vizier_service.py b/google/cloud/aiplatform_v1beta1/types/vizier_service.py index 9a7b4be68f..2b837c476e 100644 --- a/google/cloud/aiplatform_v1beta1/types/vizier_service.py +++ b/google/cloud/aiplatform_v1beta1/types/vizier_service.py @@ -24,30 +24,30 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'GetStudyRequest', - 'CreateStudyRequest', - 'ListStudiesRequest', - 'ListStudiesResponse', - 'DeleteStudyRequest', - 'LookupStudyRequest', - 'SuggestTrialsRequest', - 'SuggestTrialsResponse', - 'SuggestTrialsMetadata', - 'CreateTrialRequest', - 'GetTrialRequest', - 'ListTrialsRequest', - 'ListTrialsResponse', - 'AddTrialMeasurementRequest', - 'CompleteTrialRequest', - 'DeleteTrialRequest', - 'CheckTrialEarlyStoppingStateRequest', - 'CheckTrialEarlyStoppingStateResponse', - 'CheckTrialEarlyStoppingStateMetatdata', - 'StopTrialRequest', - 'ListOptimalTrialsRequest', - 'ListOptimalTrialsResponse', + "GetStudyRequest", + "CreateStudyRequest", + "ListStudiesRequest", + "ListStudiesResponse", + "DeleteStudyRequest", + "LookupStudyRequest", + "SuggestTrialsRequest", + "SuggestTrialsResponse", + "SuggestTrialsMetadata", + "CreateTrialRequest", + "GetTrialRequest", + "ListTrialsRequest", + "ListTrialsResponse", + "AddTrialMeasurementRequest", + "CompleteTrialRequest", + "DeleteTrialRequest", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CheckTrialEarlyStoppingStateMetatdata", + "StopTrialRequest", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", }, ) @@ -81,9 +81,7 @@ class CreateStudyRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - study = proto.Field(proto.MESSAGE, number=2, - message=gca_study.Study, - ) + study = proto.Field(proto.MESSAGE, number=2, message=gca_study.Study,) class ListStudiesRequest(proto.Message): @@ -129,9 +127,7 @@ class ListStudiesResponse(proto.Message): def raw_page(self): return self - studies = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_study.Study, - ) + studies = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Study,) next_page_token = proto.Field(proto.STRING, number=2) @@ -213,21 +209,13 @@ class SuggestTrialsResponse(proto.Message): completed. """ - trials = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_study.Trial, - ) + trials = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Trial,) - study_state = proto.Field(proto.ENUM, number=2, - enum=gca_study.Study.State, - ) + study_state = proto.Field(proto.ENUM, number=2, enum=gca_study.Study.State,) - start_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) class SuggestTrialsMetadata(proto.Message): @@ -246,8 +234,8 @@ class SuggestTrialsMetadata(proto.Message): Trial if the last suggested Trial was completed. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) client_id = proto.Field(proto.STRING, number=2) @@ -268,9 +256,7 @@ class CreateTrialRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - trial = proto.Field(proto.MESSAGE, number=2, - message=gca_study.Trial, - ) + trial = proto.Field(proto.MESSAGE, number=2, message=gca_study.Trial,) class GetTrialRequest(proto.Message): @@ -329,9 +315,7 @@ class ListTrialsResponse(proto.Message): def raw_page(self): return self - trials = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_study.Trial, - ) + trials = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Trial,) next_page_token = proto.Field(proto.STRING, number=2) @@ -351,9 +335,7 @@ class AddTrialMeasurementRequest(proto.Message): trial_name = proto.Field(proto.STRING, number=1) - measurement = proto.Field(proto.MESSAGE, number=3, - message=gca_study.Measurement, - ) + measurement = proto.Field(proto.MESSAGE, number=3, message=gca_study.Measurement,) class CompleteTrialRequest(proto.Message): @@ -380,8 +362,8 @@ class CompleteTrialRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - final_measurement = proto.Field(proto.MESSAGE, number=2, - message=gca_study.Measurement, + final_measurement = proto.Field( + proto.MESSAGE, number=2, message=gca_study.Measurement, ) trial_infeasible = proto.Field(proto.BOOL, number=3) @@ -442,8 +424,8 @@ class CheckTrialEarlyStoppingStateMetatdata(proto.Message): The Trial name. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) study = proto.Field(proto.STRING, number=2) @@ -489,8 +471,8 @@ class ListOptimalTrialsResponse(proto.Message): https://en.wikipedia.org/wiki/Pareto_efficiency """ - optimal_trials = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_study.Trial, + optimal_trials = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_study.Trial, ) diff --git a/noxfile.py b/noxfile.py index 32bd822f2b..8f3368c8dc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -27,9 +27,9 @@ BLACK_VERSION = "black==19.10b0" BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION="3.8" -SYSTEM_TEST_PYTHON_VERSIONS=["3.8"] -UNIT_TEST_PYTHON_VERSIONS=["3.6","3.7","3.8","3.9"] +DEFAULT_PYTHON_VERSION = "3.8" +SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] +UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -57,9 +57,7 @@ def lint(session): """ session.install("flake8", BLACK_VERSION) session.run( - "black", - "--check", - *BLACK_PATHS, + "black", "--check", *BLACK_PATHS, ) session.run("flake8", "google", "tests") @@ -76,8 +74,7 @@ def blacken(session): """ session.install(BLACK_VERSION) session.run( - "black", - *BLACK_PATHS, + "black", *BLACK_PATHS, ) @@ -95,12 +92,10 @@ def default(session): CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" ) session.install("asyncmock", "pytest-asyncio", "-c", constraints_path) - - session.install("mock", "pytest", "pytest-cov", "-c", constraints_path) - - + + session.install("mock", "pytest", "pytest-cov", "-c", constraints_path) + session.install("-e", ".", "-c", constraints_path) - # Run py.test against the unit tests. session.run( @@ -117,6 +112,7 @@ def default(session): *session.posargs, ) + @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" @@ -133,7 +129,7 @@ def system(session): system_test_folder_path = os.path.join("tests", "system") # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. - if os.environ.get("RUN_SYSTEM_TESTS", "true") == 'false': + if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": session.skip("RUN_SYSTEM_TESTS is set to false, skipping") # Sanity check: Only run tests if the environment variable is set. if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): @@ -155,7 +151,6 @@ def system(session): # virtualenv's dist-packages. session.install("mock", "pytest", "google-cloud-testutils", "-c", constraints_path) session.install("-e", ".", "-c", constraints_path) - # Run py.test against the system tests. if system_test_exists: @@ -164,7 +159,7 @@ def system(session): "--quiet", f"--junitxml=system_{session.python}_sponge_log.xml", system_test_path, - *session.posargs + *session.posargs, ) if system_test_folder_exists: session.run( @@ -172,11 +167,10 @@ def system(session): "--quiet", f"--junitxml=system_{session.python}_sponge_log.xml", system_test_folder_path, - *session.posargs + *session.posargs, ) - @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): """Run the final coverage report. @@ -189,23 +183,25 @@ def cover(session): session.run("coverage", "erase") + @nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" - session.install('-e', '.') - session.install('sphinx', 'alabaster', 'recommonmark') + session.install("-e", ".") + session.install("sphinx", "alabaster", "recommonmark") - shutil.rmtree(os.path.join('docs', '_build'), ignore_errors=True) + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( - 'sphinx-build', - - '-T', # show full traceback on exception - '-N', # no colors - '-b', 'html', - '-d', os.path.join('docs', '_build', 'doctrees', ''), - os.path.join('docs', ''), - os.path.join('docs', '_build', 'html', ''), + "sphinx-build", + "-T", # show full traceback on exception + "-N", # no colors + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), ) diff --git a/tests/unit/gapic/aiplatform_v1/__init__.py b/tests/unit/gapic/aiplatform_v1/__init__.py index 6a73015364..42ffdf2bc4 100644 --- a/tests/unit/gapic/aiplatform_v1/__init__.py +++ b/tests/unit/gapic/aiplatform_v1/__init__.py @@ -1,4 +1,3 @@ - # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index 118d0eefe5..c59b335074 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceAsyncClient +from google.cloud.aiplatform_v1.services.dataset_service import ( + DatasetServiceAsyncClient, +) from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceClient from google.cloud.aiplatform_v1.services.dataset_service import pagers from google.cloud.aiplatform_v1.services.dataset_service import transports @@ -63,7 +65,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -74,36 +80,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - DatasetServiceClient, - DatasetServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] +) def test_dataset_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - DatasetServiceClient, - DatasetServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] +) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -113,7 +135,7 @@ def test_dataset_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_client_get_transport_class(): @@ -127,29 +149,44 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) -def test_dataset_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -165,7 +202,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -181,7 +218,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -201,13 +238,15 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -220,26 +259,52 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_dataset_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -262,10 +327,18 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -286,9 +359,14 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -302,16 +380,23 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -324,16 +409,24 @@ def test_dataset_service_client_client_options_scopes(client_class, transport_cl client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -348,10 +441,12 @@ def test_dataset_service_client_client_options_credentials_file(client_class, tr def test_dataset_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -364,10 +459,11 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): +def test_create_dataset( + transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -375,11 +471,9 @@ def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.Cr request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_dataset(request) @@ -401,25 +495,24 @@ def test_create_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: client.create_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.CreateDatasetRequest() + @pytest.mark.asyncio -async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.CreateDatasetRequest): +async def test_create_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -427,12 +520,10 @@ async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_dataset(request) @@ -453,20 +544,16 @@ async def test_create_dataset_async_from_dict(): def test_create_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_dataset(request) @@ -477,28 +564,23 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_dataset(request) @@ -509,29 +591,21 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -539,47 +613,40 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") def test_create_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -587,31 +654,30 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) -def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): +def test_get_dataset( + transport: str = "grpc", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -619,19 +685,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.get_dataset(request) @@ -646,13 +706,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_dataset_from_dict(): @@ -663,25 +723,24 @@ def test_get_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: client.get_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetDatasetRequest() + @pytest.mark.asyncio -async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetDatasetRequest): +async def test_get_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -689,16 +748,16 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.get_dataset(request) @@ -711,13 +770,13 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=d # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -726,19 +785,15 @@ async def test_get_dataset_async_from_dict(): def test_get_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -750,27 +805,20 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -782,99 +830,79 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset( - name='name_value', - ) + client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset( - name='name_value', - ) + response = await client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) -def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): +def test_update_dataset( + transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -882,19 +910,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.update_dataset(request) @@ -909,13 +931,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_dataset_from_dict(): @@ -926,25 +948,24 @@ def test_update_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: client.update_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.UpdateDatasetRequest() + @pytest.mark.asyncio -async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.UpdateDatasetRequest): +async def test_update_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -952,16 +973,16 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.update_dataset(request) @@ -974,13 +995,13 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -989,19 +1010,15 @@ async def test_update_dataset_async_from_dict(): def test_update_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -1013,27 +1030,22 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -1045,29 +1057,24 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] def test_update_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1075,36 +1082,30 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1112,8 +1113,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1121,31 +1122,30 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): +def test_list_datasets( + transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1153,13 +1153,10 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_datasets(request) @@ -1174,7 +1171,7 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_datasets_from_dict(): @@ -1185,25 +1182,24 @@ def test_list_datasets_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: client.list_datasets() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDatasetsRequest() + @pytest.mark.asyncio -async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDatasetsRequest): +async def test_list_datasets_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1211,13 +1207,13 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_datasets(request) @@ -1230,7 +1226,7 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1239,19 +1235,15 @@ async def test_list_datasets_async_from_dict(): def test_list_datasets_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1263,28 +1255,23 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) await client.list_datasets(request) @@ -1295,138 +1282,100 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_datasets_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets( - parent='parent_value', - ) + client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_datasets_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets( - parent='parent_value', - ) + response = await client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) def test_list_datasets_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_datasets(request={}) @@ -1434,147 +1383,102 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in results) + assert all(isinstance(i, dataset.Dataset) for i in results) + def test_list_datasets_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in responses) + assert all(isinstance(i, dataset.Dataset) for i in responses) + @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): +def test_delete_dataset( + transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1582,11 +1486,9 @@ def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.De request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_dataset(request) @@ -1608,25 +1510,24 @@ def test_delete_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: client.delete_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.DeleteDatasetRequest() + @pytest.mark.asyncio -async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.DeleteDatasetRequest): +async def test_delete_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1634,12 +1535,10 @@ async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_dataset(request) @@ -1660,20 +1559,16 @@ async def test_delete_dataset_async_from_dict(): def test_delete_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_dataset(request) @@ -1684,28 +1579,23 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_dataset(request) @@ -1716,101 +1606,81 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset( - name='name_value', - ) + client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset( - name='name_value', - ) + response = await client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) -def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): +def test_import_data( + transport: str = "grpc", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1818,11 +1688,9 @@ def test_import_data(transport: str = 'grpc', request_type=dataset_service.Impor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.import_data(request) @@ -1844,25 +1712,24 @@ def test_import_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: client.import_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ImportDataRequest() + @pytest.mark.asyncio -async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ImportDataRequest): +async def test_import_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1870,12 +1737,10 @@ async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.import_data(request) @@ -1896,20 +1761,16 @@ async def test_import_data_async_from_dict(): def test_import_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.import_data(request) @@ -1920,28 +1781,23 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.import_data(request) @@ -1952,29 +1808,24 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_import_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -1982,47 +1833,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] def test_import_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -2030,31 +1881,34 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) -def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): +def test_export_data( + transport: str = "grpc", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2062,11 +1916,9 @@ def test_export_data(transport: str = 'grpc', request_type=dataset_service.Expor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_data(request) @@ -2088,25 +1940,24 @@ def test_export_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: client.export_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ExportDataRequest() + @pytest.mark.asyncio -async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ExportDataRequest): +async def test_export_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2114,12 +1965,10 @@ async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_data(request) @@ -2140,20 +1989,16 @@ async def test_export_data_async_from_dict(): def test_export_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_data(request) @@ -2164,28 +2009,23 @@ def test_export_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_data(request) @@ -2196,29 +2036,26 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2226,47 +2063,53 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) def test_export_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2274,31 +2117,38 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) -def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): +def test_list_data_items( + transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2306,13 +2156,10 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_items(request) @@ -2327,7 +2174,7 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_items_from_dict(): @@ -2338,25 +2185,24 @@ def test_list_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: client.list_data_items() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDataItemsRequest() + @pytest.mark.asyncio -async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDataItemsRequest): +async def test_list_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2364,13 +2210,13 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_items(request) @@ -2383,7 +2229,7 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2392,19 +2238,15 @@ async def test_list_data_items_async_from_dict(): def test_list_data_items_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2416,28 +2258,23 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) await client.list_data_items(request) @@ -2448,104 +2285,81 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_items_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items( - parent='parent_value', - ) + client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_items_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items( - parent='parent_value', - ) + response = await client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) def test_list_data_items_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2554,32 +2368,23 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_items(request={}) @@ -2587,18 +2392,14 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in results) + assert all(isinstance(i, data_item.DataItem) for i in results) + def test_list_data_items_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2607,40 +2408,32 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2649,46 +2442,37 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in responses) + assert all(isinstance(i, data_item.DataItem) for i in responses) + @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2697,37 +2481,31 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): +def test_get_annotation_spec( + transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2736,16 +2514,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - + name="name_value", display_name="display_name_value", etag="etag_value", ) response = client.get_annotation_spec(request) @@ -2760,11 +2533,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_annotation_spec_from_dict(): @@ -2775,25 +2548,27 @@ def test_get_annotation_spec_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: client.get_annotation_spec() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetAnnotationSpecRequest() + @pytest.mark.asyncio -async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetAnnotationSpecRequest): +async def test_get_annotation_spec_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetAnnotationSpecRequest, +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2802,14 +2577,14 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( - name='name_value', - display_name='display_name_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + ) response = await client.get_annotation_spec(request) @@ -2822,11 +2597,11 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -2835,19 +2610,17 @@ async def test_get_annotation_spec_async_from_dict(): def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2859,28 +2632,25 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) await client.get_annotation_spec(request) @@ -2891,99 +2661,85 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec( - name='name_value', - ) + client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec( - name='name_value', - ) + response = await client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) -def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): +def test_list_annotations( + transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2991,13 +2747,10 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_annotations(request) @@ -3012,7 +2765,7 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_annotations_from_dict(): @@ -3023,25 +2776,24 @@ def test_list_annotations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: client.list_annotations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListAnnotationsRequest() + @pytest.mark.asyncio -async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListAnnotationsRequest): +async def test_list_annotations_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3049,13 +2801,13 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_annotations(request) @@ -3068,7 +2820,7 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3077,19 +2829,15 @@ async def test_list_annotations_async_from_dict(): def test_list_annotations_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -3101,28 +2849,23 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) await client.list_annotations(request) @@ -3133,104 +2876,81 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_annotations_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations( - parent='parent_value', - ) + client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_annotations_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations( - parent='parent_value', - ) + response = await client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) def test_list_annotations_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3239,32 +2959,23 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_annotations(request={}) @@ -3272,18 +2983,14 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in results) + assert all(isinstance(i, annotation.Annotation) for i in results) + def test_list_annotations_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3292,40 +2999,32 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3334,46 +3033,37 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in responses) + assert all(isinstance(i, annotation.Annotation) for i in responses) + @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3382,30 +3072,23 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3416,8 +3099,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3436,8 +3118,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3465,13 +3146,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3479,13 +3163,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.DatasetServiceGrpcTransport, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) def test_dataset_service_base_transport_error(): @@ -3493,13 +3172,15 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_dataset_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3508,17 +3189,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_dataset', - 'get_dataset', - 'update_dataset', - 'list_datasets', - 'delete_dataset', - 'import_data', - 'export_data', - 'list_data_items', - 'get_annotation_spec', - 'list_annotations', - ) + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3531,23 +3212,28 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -3556,11 +3242,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3568,19 +3254,25 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.DatasetServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) -def test_dataset_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3589,15 +3281,13 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3612,38 +3302,40 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_dataset_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3651,12 +3343,11 @@ def test_dataset_service_grpc_transport_channel(): def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3665,12 +3356,22 @@ def test_dataset_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3679,7 +3380,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3695,9 +3396,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3711,17 +3410,23 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) -def test_dataset_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3738,9 +3443,7 @@ def test_dataset_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3753,16 +3456,12 @@ def test_dataset_service_transport_channel_mtls_with_adc( def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3770,16 +3469,12 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3792,19 +3487,26 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) - actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) + actual = DatasetServiceClient.annotation_path( + project, location, dataset, data_item, annotation + ) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", } path = DatasetServiceClient.annotation_path(**expected) @@ -3812,24 +3514,31 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual + def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) - actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + actual = DatasetServiceClient.annotation_spec_path( + project, location, dataset, annotation_spec + ) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", - + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + "annotation_spec": "nudibranch", } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3837,24 +3546,26 @@ def test_parse_annotation_spec_path(): actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual + def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", - + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", } path = DatasetServiceClient.data_item_path(**expected) @@ -3862,22 +3573,24 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual + def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", } path = DatasetServiceClient.dataset_path(**expected) @@ -3885,18 +3598,20 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3904,18 +3619,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = DatasetServiceClient.common_folder_path(**expected) @@ -3923,18 +3638,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = DatasetServiceClient.common_organization_path(**expected) @@ -3942,18 +3657,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = DatasetServiceClient.common_project_path(**expected) @@ -3961,20 +3676,22 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = DatasetServiceClient.common_location_path(**expected) @@ -3986,17 +3703,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index b2ae6bd168..90d41c04c0 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.endpoint_service import EndpointServiceAsyncClient +from google.cloud.aiplatform_v1.services.endpoint_service import ( + EndpointServiceAsyncClient, +) from google.cloud.aiplatform_v1.services.endpoint_service import EndpointServiceClient from google.cloud.aiplatform_v1.services.endpoint_service import pagers from google.cloud.aiplatform_v1.services.endpoint_service import transports @@ -60,7 +62,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -71,36 +77,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - EndpointServiceClient, - EndpointServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] +) def test_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - EndpointServiceClient, - EndpointServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] +) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -110,7 +132,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_client_get_transport_class(): @@ -124,29 +146,44 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) -def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -162,7 +199,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -178,7 +215,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -198,13 +235,15 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -217,26 +256,62 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -259,10 +334,18 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -283,9 +366,14 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -299,16 +387,23 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -321,16 +416,24 @@ def test_endpoint_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -345,10 +448,12 @@ def test_endpoint_service_client_client_options_credentials_file(client_class, t def test_endpoint_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -361,10 +466,11 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): +def test_create_endpoint( + transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -372,11 +478,9 @@ def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_endpoint(request) @@ -398,25 +502,24 @@ def test_create_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: client.create_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.CreateEndpointRequest() + @pytest.mark.asyncio -async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.CreateEndpointRequest): +async def test_create_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -424,12 +527,10 @@ async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_endpoint(request) @@ -450,20 +551,16 @@ async def test_create_endpoint_async_from_dict(): def test_create_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_endpoint(request) @@ -474,28 +571,23 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_endpoint(request) @@ -506,29 +598,21 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -536,47 +620,40 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") def test_create_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -584,31 +661,30 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) -def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): +def test_get_endpoint( + transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -616,19 +692,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.get_endpoint(request) @@ -643,13 +713,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_endpoint_from_dict(): @@ -660,25 +730,24 @@ def test_get_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: client.get_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.GetEndpointRequest() + @pytest.mark.asyncio -async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.GetEndpointRequest): +async def test_get_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -686,16 +755,16 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.get_endpoint(request) @@ -708,13 +777,13 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -723,19 +792,15 @@ async def test_get_endpoint_async_from_dict(): def test_get_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -747,27 +812,20 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -779,99 +837,79 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint( - name='name_value', - ) + client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_endpoint( - name='name_value', - ) + response = await client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) -def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): +def test_list_endpoints( + transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -879,13 +917,10 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_endpoints(request) @@ -900,7 +935,7 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_endpoints_from_dict(): @@ -911,25 +946,24 @@ def test_list_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: client.list_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.ListEndpointsRequest() + @pytest.mark.asyncio -async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.ListEndpointsRequest): +async def test_list_endpoints_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -937,13 +971,13 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_endpoints(request) @@ -956,7 +990,7 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -965,19 +999,15 @@ async def test_list_endpoints_async_from_dict(): def test_list_endpoints_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -989,28 +1019,23 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) await client.list_endpoints(request) @@ -1021,104 +1046,81 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_endpoints_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints( - parent='parent_value', - ) + client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_endpoints_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints( - parent='parent_value', - ) + response = await client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) def test_list_endpoints_pager(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1127,32 +1129,23 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_endpoints(request={}) @@ -1160,18 +1153,14 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in results) + assert all(isinstance(i, endpoint.Endpoint) for i in results) + def test_list_endpoints_pages(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1180,40 +1169,32 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1222,46 +1203,37 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in responses) + assert all(isinstance(i, endpoint.Endpoint) for i in responses) + @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1270,37 +1242,31 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): +def test_update_endpoint( + transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1308,19 +1274,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.update_endpoint(request) @@ -1335,13 +1295,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_endpoint_from_dict(): @@ -1352,25 +1312,24 @@ def test_update_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: client.update_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UpdateEndpointRequest() + @pytest.mark.asyncio -async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UpdateEndpointRequest): +async def test_update_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1378,16 +1337,16 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.update_endpoint(request) @@ -1400,13 +1359,13 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -1415,19 +1374,15 @@ async def test_update_endpoint_async_from_dict(): def test_update_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1439,28 +1394,25 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) await client.update_endpoint(request) @@ -1471,29 +1423,24 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] def test_update_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1501,45 +1448,41 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1547,31 +1490,30 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): +def test_delete_endpoint( + transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1579,11 +1521,9 @@ def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_endpoint(request) @@ -1605,25 +1545,24 @@ def test_delete_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: client.delete_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeleteEndpointRequest() + @pytest.mark.asyncio -async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeleteEndpointRequest): +async def test_delete_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1631,12 +1570,10 @@ async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_endpoint(request) @@ -1657,20 +1594,16 @@ async def test_delete_endpoint_async_from_dict(): def test_delete_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_endpoint(request) @@ -1681,28 +1614,23 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_endpoint(request) @@ -1713,101 +1641,81 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_endpoint( - name='name_value', - ) + client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint( - name='name_value', - ) + response = await client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) -def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): +def test_deploy_model( + transport: str = "grpc", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1815,11 +1723,9 @@ def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.Dep request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.deploy_model(request) @@ -1841,25 +1747,24 @@ def test_deploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: client.deploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeployModelRequest() + @pytest.mark.asyncio -async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeployModelRequest): +async def test_deploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1867,12 +1772,10 @@ async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.deploy_model(request) @@ -1893,20 +1796,16 @@ async def test_deploy_model_async_from_dict(): def test_deploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.deploy_model(request) @@ -1917,28 +1816,23 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.deploy_model(request) @@ -1949,30 +1843,29 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_deploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -1980,51 +1873,63 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2032,34 +1937,45 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) -def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): +def test_undeploy_model( + transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2067,11 +1983,9 @@ def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.U request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.undeploy_model(request) @@ -2093,25 +2007,24 @@ def test_undeploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: client.undeploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UndeployModelRequest() + @pytest.mark.asyncio -async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UndeployModelRequest): +async def test_undeploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2119,12 +2032,10 @@ async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.undeploy_model(request) @@ -2145,20 +2056,16 @@ async def test_undeploy_model_async_from_dict(): def test_undeploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.undeploy_model(request) @@ -2169,28 +2076,23 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.undeploy_model(request) @@ -2201,30 +2103,23 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_undeploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2232,51 +2127,45 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2284,27 +2173,25 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @@ -2315,8 +2202,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2335,8 +2221,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -2364,13 +2249,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2378,13 +2266,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.EndpointServiceGrpcTransport, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) def test_endpoint_service_base_transport_error(): @@ -2392,13 +2275,15 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2407,14 +2292,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_endpoint', - 'get_endpoint', - 'list_endpoints', - 'update_endpoint', - 'delete_endpoint', - 'deploy_model', - 'undeploy_model', - ) + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2427,23 +2312,28 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2452,11 +2342,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -2464,19 +2354,25 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.EndpointServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) -def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -2485,15 +2381,13 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2508,38 +2402,40 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_endpoint_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2547,12 +2443,11 @@ def test_endpoint_service_grpc_transport_channel(): def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2561,12 +2456,22 @@ def test_endpoint_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2575,7 +2480,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2591,9 +2496,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2607,17 +2510,23 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) -def test_endpoint_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2634,9 +2543,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2649,16 +2556,12 @@ def test_endpoint_service_transport_channel_mtls_with_adc( def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2666,16 +2569,12 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2686,17 +2585,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = EndpointServiceClient.endpoint_path(**expected) @@ -2704,22 +2604,24 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = EndpointServiceClient.model_path(**expected) @@ -2727,18 +2629,20 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2746,18 +2650,18 @@ def test_parse_common_billing_account_path(): actual = EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = EndpointServiceClient.common_folder_path(**expected) @@ -2765,18 +2669,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = EndpointServiceClient.common_organization_path(**expected) @@ -2784,18 +2688,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = EndpointServiceClient.common_project_path(**expected) @@ -2803,20 +2707,22 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = EndpointServiceClient.common_location_path(**expected) @@ -2828,17 +2734,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index c6acd32ec8..ea8d1d502b 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -41,7 +41,9 @@ from google.cloud.aiplatform_v1.services.job_service import transports from google.cloud.aiplatform_v1.types import accelerator_type from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -50,7 +52,9 @@ from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import env_var from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state @@ -77,7 +81,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -88,36 +96,45 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [ - JobServiceClient, - JobServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - JobServiceClient, - JobServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -127,7 +144,7 @@ def test_job_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_client_get_transport_class(): @@ -141,29 +158,42 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) -def test_job_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -179,7 +209,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -195,7 +225,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -215,13 +245,15 @@ def test_job_service_client_client_options(client_class, transport_class, transp client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -234,26 +266,50 @@ def test_job_service_client_client_options(client_class, transport_class, transp client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_job_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -276,10 +332,18 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -300,9 +364,14 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -316,16 +385,23 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -338,16 +414,24 @@ def test_job_service_client_client_options_scopes(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -362,11 +446,11 @@ def test_job_service_client_client_options_credentials_file(client_class, transp def test_job_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -378,10 +462,11 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): +def test_create_custom_job( + transport: str = "grpc", request_type=job_service.CreateCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -390,16 +475,13 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_custom_job(request) @@ -414,9 +496,9 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -429,25 +511,26 @@ def test_create_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: client.create_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateCustomJobRequest() + @pytest.mark.asyncio -async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateCustomJobRequest): +async def test_create_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -456,14 +539,16 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_custom_job(request) @@ -476,9 +561,9 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -489,19 +574,17 @@ async def test_create_custom_job_async_from_dict(): def test_create_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -513,28 +596,25 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) await client.create_custom_job(request) @@ -545,29 +625,24 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -575,45 +650,43 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") def test_create_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -621,31 +694,30 @@ async def test_create_custom_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") @pytest.mark.asyncio async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) -def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): +def test_get_custom_job( + transport: str = "grpc", request_type=job_service.GetCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -653,17 +725,12 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_custom_job(request) @@ -678,9 +745,9 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus assert isinstance(response, custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -693,25 +760,24 @@ def test_get_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: client.get_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetCustomJobRequest() + @pytest.mark.asyncio -async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetCustomJobRequest): +async def test_get_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -719,15 +785,15 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_custom_job(request) @@ -740,9 +806,9 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -753,19 +819,15 @@ async def test_get_custom_job_async_from_dict(): def test_get_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: call.return_value = custom_job.CustomJob() client.get_custom_job(request) @@ -777,28 +839,23 @@ def test_get_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) await client.get_custom_job(request) @@ -809,99 +866,81 @@ async def test_get_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_custom_job( - name='name_value', - ) + client.get_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_custom_job( - job_service.GetCustomJobRequest(), - name='name_value', + job_service.GetCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_custom_job( - name='name_value', - ) + response = await client.get_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_custom_job( - job_service.GetCustomJobRequest(), - name='name_value', + job_service.GetCustomJobRequest(), name="name_value", ) -def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): +def test_list_custom_jobs( + transport: str = "grpc", request_type=job_service.ListCustomJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -909,13 +948,10 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_custom_jobs(request) @@ -930,7 +966,7 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_custom_jobs_from_dict(): @@ -941,25 +977,24 @@ def test_list_custom_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: client.list_custom_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListCustomJobsRequest() + @pytest.mark.asyncio -async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListCustomJobsRequest): +async def test_list_custom_jobs_async( + transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -967,13 +1002,11 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_custom_jobs(request) @@ -986,7 +1019,7 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListCustomJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -995,19 +1028,15 @@ async def test_list_custom_jobs_async_from_dict(): def test_list_custom_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: call.return_value = job_service.ListCustomJobsResponse() client.list_custom_jobs(request) @@ -1019,28 +1048,23 @@ def test_list_custom_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) await client.list_custom_jobs(request) @@ -1051,104 +1075,81 @@ async def test_list_custom_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_custom_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_custom_jobs( - parent='parent_value', - ) + client.list_custom_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_custom_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_custom_jobs( - job_service.ListCustomJobsRequest(), - parent='parent_value', + job_service.ListCustomJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_custom_jobs( - parent='parent_value', - ) + response = await client.list_custom_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), - parent='parent_value', + job_service.ListCustomJobsRequest(), parent="parent_value", ) def test_list_custom_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1157,32 +1158,21 @@ def test_list_custom_jobs_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_custom_jobs(request={}) @@ -1190,18 +1180,14 @@ def test_list_custom_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) - for i in results) + assert all(isinstance(i, custom_job.CustomJob) for i in results) + def test_list_custom_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1210,40 +1196,30 @@ def test_list_custom_jobs_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1252,46 +1228,35 @@ async def test_list_custom_jobs_async_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) - for i in responses) + assert all(isinstance(i, custom_job.CustomJob) for i in responses) + @pytest.mark.asyncio async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1300,37 +1265,29 @@ async def test_list_custom_jobs_async_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_custom_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): +def test_delete_custom_job( + transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1339,10 +1296,10 @@ def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.Del # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_custom_job(request) @@ -1364,25 +1321,26 @@ def test_delete_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: client.delete_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteCustomJobRequest() + @pytest.mark.asyncio -async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteCustomJobRequest): +async def test_delete_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1391,11 +1349,11 @@ async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_custom_job(request) @@ -1416,20 +1374,18 @@ async def test_delete_custom_job_async_from_dict(): def test_delete_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_custom_job(request) @@ -1440,28 +1396,25 @@ def test_delete_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_custom_job(request) @@ -1472,101 +1425,85 @@ async def test_delete_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_custom_job( - name='name_value', - ) + client.delete_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_custom_job( - job_service.DeleteCustomJobRequest(), - name='name_value', + job_service.DeleteCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_custom_job( - name='name_value', - ) + response = await client.delete_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), - name='name_value', + job_service.DeleteCustomJobRequest(), name="name_value", ) -def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): +def test_cancel_custom_job( + transport: str = "grpc", request_type=job_service.CancelCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1575,8 +1512,8 @@ def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.Can # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1600,25 +1537,26 @@ def test_cancel_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: client.cancel_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelCustomJobRequest() + @pytest.mark.asyncio -async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelCustomJobRequest): +async def test_cancel_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1627,8 +1565,8 @@ async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1650,19 +1588,17 @@ async def test_cancel_custom_job_async_from_dict(): def test_cancel_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: call.return_value = None client.cancel_custom_job(request) @@ -1674,27 +1610,22 @@ def test_cancel_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_custom_job(request) @@ -1706,99 +1637,83 @@ async def test_cancel_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_custom_job( - name='name_value', - ) + client.cancel_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_custom_job( - job_service.CancelCustomJobRequest(), - name='name_value', + job_service.CancelCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_custom_job( - name='name_value', - ) + response = await client.cancel_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), - name='name_value', + job_service.CancelCustomJobRequest(), name="name_value", ) -def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): +def test_create_data_labeling_job( + transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1807,28 +1722,19 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob( - name='name_value', - - display_name='display_name_value', - - datasets=['datasets_value'], - + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], labeler_count=1375, - - instruction_uri='instruction_uri_value', - - inputs_schema_uri='inputs_schema_uri_value', - + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - - specialist_pools=['specialist_pools_value'], - + specialist_pools=["specialist_pools_value"], ) response = client.create_data_labeling_job(request) @@ -1843,23 +1749,23 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_create_data_labeling_job_from_dict(): @@ -1870,25 +1776,27 @@ def test_create_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: client.create_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateDataLabelingJobRequest): +async def test_create_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1897,20 +1805,22 @@ async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( - name='name_value', - display_name='display_name_value', - datasets=['datasets_value'], - labeler_count=1375, - instruction_uri='instruction_uri_value', - inputs_schema_uri='inputs_schema_uri_value', - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=['specialist_pools_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) response = await client.create_data_labeling_job(request) @@ -1923,23 +1833,23 @@ async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] @pytest.mark.asyncio @@ -1948,19 +1858,17 @@ async def test_create_data_labeling_job_async_from_dict(): def test_create_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: call.return_value = gca_data_labeling_job.DataLabelingJob() client.create_data_labeling_job(request) @@ -1972,28 +1880,25 @@ def test_create_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) await client.create_data_labeling_job(request) @@ -2004,29 +1909,24 @@ async def test_create_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_data_labeling_job( - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2034,45 +1934,45 @@ def test_create_data_labeling_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_data_labeling_job( - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2080,31 +1980,32 @@ async def test_create_data_labeling_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) -def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): +def test_get_data_labeling_job( + transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2113,28 +2014,19 @@ def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob( - name='name_value', - - display_name='display_name_value', - - datasets=['datasets_value'], - + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], labeler_count=1375, - - instruction_uri='instruction_uri_value', - - inputs_schema_uri='inputs_schema_uri_value', - + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - - specialist_pools=['specialist_pools_value'], - + specialist_pools=["specialist_pools_value"], ) response = client.get_data_labeling_job(request) @@ -2149,23 +2041,23 @@ def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_get_data_labeling_job_from_dict(): @@ -2176,25 +2068,26 @@ def test_get_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: client.get_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetDataLabelingJobRequest): +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2203,20 +2096,22 @@ async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( - name='name_value', - display_name='display_name_value', - datasets=['datasets_value'], - labeler_count=1375, - instruction_uri='instruction_uri_value', - inputs_schema_uri='inputs_schema_uri_value', - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=['specialist_pools_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) response = await client.get_data_labeling_job(request) @@ -2229,23 +2124,23 @@ async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] @pytest.mark.asyncio @@ -2254,19 +2149,17 @@ async def test_get_data_labeling_job_async_from_dict(): def test_get_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -2278,28 +2171,25 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) await client.get_data_labeling_job(request) @@ -2310,99 +2200,85 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job( - name='name_value', - ) + client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job( - name='name_value', - ) + response = await client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) -def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): +def test_list_data_labeling_jobs( + transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2411,12 +2287,11 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_labeling_jobs(request) @@ -2431,7 +2306,7 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_labeling_jobs_from_dict(): @@ -2442,25 +2317,27 @@ def test_list_data_labeling_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: client.list_data_labeling_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListDataLabelingJobsRequest() + @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListDataLabelingJobsRequest): +async def test_list_data_labeling_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListDataLabelingJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2469,12 +2346,14 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_labeling_jobs(request) @@ -2487,7 +2366,7 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', re # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2496,19 +2375,17 @@ async def test_list_data_labeling_jobs_async_from_dict(): def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2520,28 +2397,25 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) await client.list_data_labeling_jobs(request) @@ -2552,104 +2426,87 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_labeling_jobs( - parent='parent_value', - ) + client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs( - parent='parent_value', - ) + response = await client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2658,17 +2515,14 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2681,9 +2535,7 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_labeling_jobs(request={}) @@ -2691,18 +2543,16 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in results) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) + def test_list_data_labeling_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2711,17 +2561,14 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2732,19 +2579,20 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2753,17 +2601,14 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2774,25 +2619,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in responses) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2801,17 +2646,14 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2824,14 +2666,15 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): +def test_delete_data_labeling_job( + transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2840,10 +2683,10 @@ def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_data_labeling_job(request) @@ -2865,25 +2708,27 @@ def test_delete_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: client.delete_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteDataLabelingJobRequest): +async def test_delete_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2892,11 +2737,11 @@ async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_data_labeling_job(request) @@ -2917,20 +2762,18 @@ async def test_delete_data_labeling_job_async_from_dict(): def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_data_labeling_job(request) @@ -2941,28 +2784,25 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_data_labeling_job(request) @@ -2973,101 +2813,85 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_data_labeling_job( - name='name_value', - ) + client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job( - name='name_value', - ) + response = await client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" -@pytest.mark.asyncio -async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) +@pytest.mark.asyncio +async def test_delete_data_labeling_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) -def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): +def test_cancel_data_labeling_job( + transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3076,8 +2900,8 @@ def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -3101,25 +2925,27 @@ def test_cancel_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: client.cancel_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelDataLabelingJobRequest): +async def test_cancel_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3128,8 +2954,8 @@ async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -3151,19 +2977,17 @@ async def test_cancel_data_labeling_job_async_from_dict(): def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -3175,27 +2999,22 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -3207,99 +3026,84 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_data_labeling_job( - name='name_value', - ) + client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job( - name='name_value', - ) + response = await client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) -def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): +def test_create_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3308,22 +3112,16 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_hyperparameter_tuning_job(request) @@ -3338,9 +3136,9 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3359,25 +3157,27 @@ def test_create_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: client.create_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateHyperparameterTuningJobRequest): +async def test_create_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3386,17 +3186,19 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_hyperparameter_tuning_job(request) @@ -3409,9 +3211,9 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Establish that the response is the type that we expect. assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3428,19 +3230,17 @@ async def test_create_hyperparameter_tuning_job_async_from_dict(): def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -3452,28 +3252,25 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.create_hyperparameter_tuning_job(request) @@ -3484,29 +3281,26 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3514,45 +3308,51 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3560,31 +3360,36 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) -def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): +def test_get_hyperparameter_tuning_job( + transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3593,22 +3398,16 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_hyperparameter_tuning_job(request) @@ -3623,9 +3422,9 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3644,25 +3443,27 @@ def test_get_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: client.get_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetHyperparameterTuningJobRequest): +async def test_get_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3671,17 +3472,19 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_hyperparameter_tuning_job(request) @@ -3694,9 +3497,9 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3713,19 +3516,17 @@ async def test_get_hyperparameter_tuning_job_async_from_dict(): def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3737,28 +3538,25 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.get_hyperparameter_tuning_job(request) @@ -3769,99 +3567,86 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job( - name='name_value', - ) + client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) -def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): +def test_list_hyperparameter_tuning_jobs( + transport: str = "grpc", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3870,12 +3655,11 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3890,7 +3674,7 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3901,25 +3685,27 @@ def test_list_hyperparameter_tuning_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: client.list_hyperparameter_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListHyperparameterTuningJobsRequest): +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3928,12 +3714,14 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3946,7 +3734,7 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3955,19 +3743,17 @@ async def test_list_hyperparameter_tuning_jobs_async_from_dict(): def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3979,28 +3765,25 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) await client.list_hyperparameter_tuning_jobs(request) @@ -4011,104 +3794,87 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4117,17 +3883,16 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4140,9 +3905,7 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -4150,18 +3913,19 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results + ) + def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4170,17 +3934,16 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4191,19 +3954,20 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4212,17 +3976,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4233,25 +3996,28 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses + ) + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4260,17 +4026,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4281,16 +4046,20 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: + async for page_ in ( + await client.list_hyperparameter_tuning_jobs(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): +def test_delete_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4299,10 +4068,10 @@ def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_hyperparameter_tuning_job(request) @@ -4324,25 +4093,27 @@ def test_delete_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: client.delete_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteHyperparameterTuningJobRequest): +async def test_delete_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4351,11 +4122,11 @@ async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_hyperparameter_tuning_job(request) @@ -4376,20 +4147,18 @@ async def test_delete_hyperparameter_tuning_job_async_from_dict(): def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_hyperparameter_tuning_job(request) @@ -4400,28 +4169,25 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_hyperparameter_tuning_job(request) @@ -4432,101 +4198,86 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job( - name='name_value', - ) + client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) -def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): +def test_cancel_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4535,8 +4286,8 @@ def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -4560,25 +4311,27 @@ def test_cancel_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: client.cancel_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelHyperparameterTuningJobRequest): +async def test_cancel_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4587,8 +4340,8 @@ async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -4610,19 +4363,17 @@ async def test_cancel_hyperparameter_tuning_job_async_from_dict(): def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -4634,27 +4385,22 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4666,99 +4412,83 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) -def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest): +def test_create_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4767,18 +4497,14 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_batch_prediction_job(request) @@ -4793,11 +4519,11 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -4810,25 +4536,27 @@ def test_create_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: client.create_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateBatchPredictionJobRequest): +async def test_create_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4837,15 +4565,17 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_batch_prediction_job(request) @@ -4858,11 +4588,11 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -4872,20 +4602,18 @@ async def test_create_batch_prediction_job_async_from_dict(): await test_create_batch_prediction_job_async(request_type=dict) -def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - +def test_create_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: call.return_value = gca_batch_prediction_job.BatchPredictionJob() client.create_batch_prediction_job(request) @@ -4897,28 +4625,25 @@ def test_create_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) await client.create_batch_prediction_job(request) @@ -4929,29 +4654,26 @@ async def test_create_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -4959,45 +4681,51 @@ def test_create_batch_prediction_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -5005,31 +4733,36 @@ async def test_create_batch_prediction_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) -def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest): +def test_get_batch_prediction_job( + transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5038,18 +4771,14 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_batch_prediction_job(request) @@ -5064,11 +4793,11 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -5081,25 +4810,27 @@ def test_get_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: client.get_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetBatchPredictionJobRequest): +async def test_get_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5108,15 +4839,17 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_batch_prediction_job(request) @@ -5129,11 +4862,11 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -5144,19 +4877,17 @@ async def test_get_batch_prediction_job_async_from_dict(): def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: call.return_value = batch_prediction_job.BatchPredictionJob() client.get_batch_prediction_job(request) @@ -5168,28 +4899,25 @@ def test_get_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) await client.get_batch_prediction_job(request) @@ -5200,99 +4928,85 @@ async def test_get_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_batch_prediction_job( - name='name_value', - ) + client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_batch_prediction_job( - name='name_value', - ) + response = await client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) -def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest): +def test_list_batch_prediction_jobs( + transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5301,12 +5015,11 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_batch_prediction_jobs(request) @@ -5321,7 +5034,7 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_batch_prediction_jobs_from_dict(): @@ -5332,25 +5045,27 @@ def test_list_batch_prediction_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: client.list_batch_prediction_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListBatchPredictionJobsRequest() + @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListBatchPredictionJobsRequest): +async def test_list_batch_prediction_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListBatchPredictionJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5359,12 +5074,14 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_batch_prediction_jobs(request) @@ -5377,7 +5094,7 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -5386,19 +5103,17 @@ async def test_list_batch_prediction_jobs_async_from_dict(): def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: call.return_value = job_service.ListBatchPredictionJobsResponse() client.list_batch_prediction_jobs(request) @@ -5410,28 +5125,25 @@ def test_list_batch_prediction_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) await client.list_batch_prediction_jobs(request) @@ -5442,104 +5154,87 @@ async def test_list_batch_prediction_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_batch_prediction_jobs( - parent='parent_value', - ) + client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs( - parent='parent_value', - ) + response = await client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5548,17 +5243,14 @@ def test_list_batch_prediction_jobs_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5571,9 +5263,7 @@ def test_list_batch_prediction_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_batch_prediction_jobs(request={}) @@ -5581,18 +5271,18 @@ def test_list_batch_prediction_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in results) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results + ) + def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5601,17 +5291,14 @@ def test_list_batch_prediction_jobs_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5622,19 +5309,20 @@ def test_list_batch_prediction_jobs_pages(): RuntimeError, ) pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5643,17 +5331,14 @@ async def test_list_batch_prediction_jobs_async_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5664,25 +5349,27 @@ async def test_list_batch_prediction_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in responses) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses + ) + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5691,17 +5378,14 @@ async def test_list_batch_prediction_jobs_async_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5714,14 +5398,15 @@ async def test_list_batch_prediction_jobs_async_pages(): pages = [] async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest): +def test_delete_batch_prediction_job( + transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5730,10 +5415,10 @@ def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_batch_prediction_job(request) @@ -5755,25 +5440,27 @@ def test_delete_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: client.delete_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteBatchPredictionJobRequest): +async def test_delete_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5782,11 +5469,11 @@ async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_batch_prediction_job(request) @@ -5807,20 +5494,18 @@ async def test_delete_batch_prediction_job_async_from_dict(): def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_batch_prediction_job(request) @@ -5831,28 +5516,25 @@ def test_delete_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_batch_prediction_job(request) @@ -5863,101 +5545,85 @@ async def test_delete_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_batch_prediction_job( - name='name_value', - ) + client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job( - name='name_value', - ) + response = await client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) -def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest): +def test_cancel_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5966,8 +5632,8 @@ def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -5991,25 +5657,27 @@ def test_cancel_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: client.cancel_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelBatchPredictionJobRequest): +async def test_cancel_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6018,8 +5686,8 @@ async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -6041,19 +5709,17 @@ async def test_cancel_batch_prediction_job_async_from_dict(): def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = None client.cancel_batch_prediction_job(request) @@ -6065,27 +5731,22 @@ def test_cancel_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_batch_prediction_job(request) @@ -6097,92 +5758,75 @@ async def test_cancel_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_batch_prediction_job( - name='name_value', - ) + client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job( - name='name_value', - ) + response = await client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @@ -6193,8 +5837,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -6213,8 +5856,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -6242,13 +5884,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.JobServiceGrpcTransport, - transports.JobServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport,], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -6256,13 +5898,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.JobServiceGrpcTransport, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.JobServiceGrpcTransport,) def test_job_service_base_transport_error(): @@ -6270,13 +5907,15 @@ def test_job_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_job_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -6285,27 +5924,27 @@ def test_job_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_custom_job', - 'get_custom_job', - 'list_custom_jobs', - 'delete_custom_job', - 'cancel_custom_job', - 'create_data_labeling_job', - 'get_data_labeling_job', - 'list_data_labeling_jobs', - 'delete_data_labeling_job', - 'cancel_data_labeling_job', - 'create_hyperparameter_tuning_job', - 'get_hyperparameter_tuning_job', - 'list_hyperparameter_tuning_jobs', - 'delete_hyperparameter_tuning_job', - 'cancel_hyperparameter_tuning_job', - 'create_batch_prediction_job', - 'get_batch_prediction_job', - 'list_batch_prediction_jobs', - 'delete_batch_prediction_job', - 'cancel_batch_prediction_job', - ) + "create_custom_job", + "get_custom_job", + "list_custom_jobs", + "delete_custom_job", + "cancel_custom_job", + "create_data_labeling_job", + "get_data_labeling_job", + "list_data_labeling_jobs", + "delete_data_labeling_job", + "cancel_data_labeling_job", + "create_hyperparameter_tuning_job", + "get_hyperparameter_tuning_job", + "list_hyperparameter_tuning_jobs", + "delete_hyperparameter_tuning_job", + "cancel_hyperparameter_tuning_job", + "create_batch_prediction_job", + "get_batch_prediction_job", + "list_batch_prediction_jobs", + "delete_batch_prediction_job", + "cancel_batch_prediction_job", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -6318,23 +5957,28 @@ def test_job_service_base_transport(): def test_job_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_job_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport() @@ -6343,11 +5987,11 @@ def test_job_service_base_transport_with_adc(): def test_job_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) JobServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -6355,19 +5999,22 @@ def test_job_service_auth_adc(): def test_job_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.JobServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -6376,15 +6023,13 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -6399,38 +6044,40 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_host_with_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_job_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6438,12 +6085,11 @@ def test_job_service_grpc_transport_channel(): def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6452,12 +6098,17 @@ def test_job_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -6466,7 +6117,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -6482,9 +6133,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6498,17 +6147,20 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -6525,9 +6177,7 @@ def test_job_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6540,16 +6190,12 @@ def test_job_service_transport_channel_mtls_with_adc( def test_job_service_grpc_lro_client(): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6557,16 +6203,12 @@ def test_job_service_grpc_lro_client(): def test_job_service_grpc_lro_async_client(): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6577,17 +6219,20 @@ def test_batch_prediction_job_path(): location = "clam" batch_prediction_job = "whelk" - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) - actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, location=location, batch_prediction_job=batch_prediction_job, + ) + actual = JobServiceClient.batch_prediction_job_path( + project, location, batch_prediction_job + ) assert expected == actual def test_parse_batch_prediction_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", } path = JobServiceClient.batch_prediction_job_path(**expected) @@ -6595,22 +6240,24 @@ def test_parse_batch_prediction_job_path(): actual = JobServiceClient.parse_batch_prediction_job_path(path) assert expected == actual + def test_custom_job_path(): project = "cuttlefish" location = "mussel" custom_job = "winkle" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) actual = JobServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", - + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", } path = JobServiceClient.custom_job_path(**expected) @@ -6618,22 +6265,26 @@ def test_parse_custom_job_path(): actual = JobServiceClient.parse_custom_job_path(path) assert expected == actual + def test_data_labeling_job_path(): project = "squid" location = "clam" data_labeling_job = "whelk" - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) - actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) + actual = JobServiceClient.data_labeling_job_path( + project, location, data_labeling_job + ) assert expected == actual def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", } path = JobServiceClient.data_labeling_job_path(**expected) @@ -6641,22 +6292,24 @@ def test_parse_data_labeling_job_path(): actual = JobServiceClient.parse_data_labeling_job_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = JobServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = JobServiceClient.dataset_path(**expected) @@ -6664,22 +6317,28 @@ def test_parse_dataset_path(): actual = JobServiceClient.parse_dataset_path(path) assert expected == actual + def test_hyperparameter_tuning_job_path(): project = "squid" location = "clam" hyperparameter_tuning_job = "whelk" - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) - actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + actual = JobServiceClient.hyperparameter_tuning_job_path( + project, location, hyperparameter_tuning_job + ) assert expected == actual def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "hyperparameter_tuning_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -6687,22 +6346,24 @@ def test_parse_hyperparameter_tuning_job_path(): actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = JobServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = JobServiceClient.model_path(**expected) @@ -6710,24 +6371,26 @@ def test_parse_model_path(): actual = JobServiceClient.parse_model_path(path) assert expected == actual + def test_trial_path(): project = "squid" location = "clam" study = "whelk" trial = "octopus" - expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) actual = JobServiceClient.trial_path(project, location, study, trial) assert expected == actual def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", - + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", } path = JobServiceClient.trial_path(**expected) @@ -6735,18 +6398,20 @@ def test_parse_trial_path(): actual = JobServiceClient.parse_trial_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = JobServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = JobServiceClient.common_billing_account_path(**expected) @@ -6754,18 +6419,18 @@ def test_parse_common_billing_account_path(): actual = JobServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = JobServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = JobServiceClient.common_folder_path(**expected) @@ -6773,18 +6438,18 @@ def test_parse_common_folder_path(): actual = JobServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = JobServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = JobServiceClient.common_organization_path(**expected) @@ -6792,18 +6457,18 @@ def test_parse_common_organization_path(): actual = JobServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = JobServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = JobServiceClient.common_project_path(**expected) @@ -6811,20 +6476,22 @@ def test_parse_common_project_path(): actual = JobServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = JobServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = JobServiceClient.common_location_path(**expected) @@ -6836,17 +6503,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = JobServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 2f1c62f3ef..d1b0b51231 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceAsyncClient +from google.cloud.aiplatform_v1.services.migration_service import ( + MigrationServiceAsyncClient, +) from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceClient from google.cloud.aiplatform_v1.services.migration_service import pagers from google.cloud.aiplatform_v1.services.migration_service import transports @@ -53,7 +55,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -64,36 +70,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -103,7 +126,7 @@ def test_migration_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_client_get_transport_class(): @@ -117,29 +140,44 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) -def test_migration_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -155,7 +193,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -171,7 +209,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -191,13 +229,15 @@ def test_migration_service_client_client_options(client_class, transport_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -210,26 +250,62 @@ def test_migration_service_client_client_options(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -252,10 +328,18 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -276,9 +360,14 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -292,16 +381,23 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -314,16 +410,24 @@ def test_migration_service_client_client_options_scopes(client_class, transport_ client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -338,10 +442,12 @@ def test_migration_service_client_client_options_credentials_file(client_class, def test_migration_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -354,10 +460,12 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -366,12 +474,11 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_migratable_resources(request) @@ -386,7 +493,7 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_migratable_resources_from_dict(): @@ -397,25 +504,27 @@ def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.SearchMigratableResourcesRequest() + @pytest.mark.asyncio -async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): +async def test_search_migratable_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -424,12 +533,14 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.search_migratable_resources(request) @@ -442,7 +553,7 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -451,19 +562,17 @@ async def test_search_migratable_resources_async_from_dict(): def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -475,10 +584,7 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -490,13 +596,15 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) await client.search_migratable_resources(request) @@ -507,49 +615,39 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources( - parent='parent_value', - ) + client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) @@ -561,24 +659,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources( - parent='parent_value', - ) + response = await client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -591,20 +689,17 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -613,17 +708,14 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -636,9 +728,7 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.search_migratable_resources(request={}) @@ -646,18 +736,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in results) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in results + ) + def test_search_migratable_resources_pages(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -666,17 +756,14 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -687,19 +774,20 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -708,17 +796,14 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -729,25 +814,27 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in responses) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in responses + ) + @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -756,17 +843,14 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -779,14 +863,15 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): +def test_batch_migrate_resources( + transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -795,10 +880,10 @@ def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_migrate_resources(request) @@ -820,25 +905,27 @@ def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.BatchMigrateResourcesRequest() + @pytest.mark.asyncio -async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): +async def test_batch_migrate_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.BatchMigrateResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -847,11 +934,11 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_migrate_resources(request) @@ -872,20 +959,18 @@ async def test_batch_migrate_resources_async_from_dict(): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_migrate_resources(request) @@ -896,10 +981,7 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -911,13 +993,15 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_migrate_resources(request) @@ -928,29 +1012,30 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -958,23 +1043,33 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -986,19 +1081,25 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -1006,9 +1107,15 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] @pytest.mark.asyncio @@ -1022,8 +1129,14 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -1034,8 +1147,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1054,8 +1166,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1083,13 +1194,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1097,13 +1211,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MigrationServiceGrpcTransport, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) def test_migration_service_base_transport_error(): @@ -1111,13 +1220,15 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1126,9 +1237,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'search_migratable_resources', - 'batch_migrate_resources', - ) + "search_migratable_resources", + "batch_migrate_resources", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1141,23 +1252,28 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1166,11 +1282,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1178,19 +1294,25 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MigrationServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1199,15 +1321,13 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1222,38 +1342,40 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_migration_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1261,12 +1383,11 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1275,12 +1396,22 @@ def test_migration_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1289,7 +1420,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1305,9 +1436,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1321,17 +1450,23 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1348,9 +1483,7 @@ def test_migration_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1363,16 +1496,12 @@ def test_migration_service_transport_channel_mtls_with_adc( def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1380,16 +1509,12 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1400,17 +1525,20 @@ def test_annotated_dataset_path(): dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) - actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) + actual = MigrationServiceClient.annotated_dataset_path( + project, dataset, annotated_dataset + ) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", - + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1418,22 +1546,24 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1441,20 +1571,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "squid" dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", - + "project": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1462,22 +1594,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "oyster" location = "nudibranch" dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", - "dataset": "nautilus", - + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1485,22 +1619,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - + "project": "clam", + "location": "whelk", + "model": "octopus", } path = MigrationServiceClient.model_path(**expected) @@ -1508,22 +1644,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", - + "project": "mussel", + "location": "winkle", + "model": "nautilus", } path = MigrationServiceClient.model_path(**expected) @@ -1531,22 +1669,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + expected = "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", - + "project": "clam", + "model": "whelk", + "version": "octopus", } path = MigrationServiceClient.version_path(**expected) @@ -1554,18 +1694,20 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", - + "billing_account": "nudibranch", } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1573,18 +1715,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", - + "folder": "mussel", } path = MigrationServiceClient.common_folder_path(**expected) @@ -1592,18 +1734,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", - + "organization": "nautilus", } path = MigrationServiceClient.common_organization_path(**expected) @@ -1611,18 +1753,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", - + "project": "abalone", } path = MigrationServiceClient.common_project_path(**expected) @@ -1630,20 +1772,22 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", - + "project": "whelk", + "location": "octopus", } path = MigrationServiceClient.common_location_path(**expected) @@ -1655,17 +1799,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index 0011bd1129..f74aea2dea 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -64,7 +64,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -75,36 +79,45 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [ - ModelServiceClient, - ModelServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - ModelServiceClient, - ModelServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -114,7 +127,7 @@ def test_model_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_model_service_client_get_transport_class(): @@ -128,29 +141,42 @@ def test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -def test_model_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -166,7 +192,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -182,7 +208,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -202,13 +228,15 @@ def test_model_service_client_client_options(client_class, transport_class, tran client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -221,26 +249,50 @@ def test_model_service_client_client_options(client_class, transport_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -263,10 +315,18 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -287,9 +347,14 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -303,16 +368,23 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -325,16 +397,24 @@ def test_model_service_client_client_options_scopes(client_class, transport_clas client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -349,11 +429,11 @@ def test_model_service_client_client_options_credentials_file(client_class, tran def test_model_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -365,10 +445,11 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): +def test_upload_model( + transport: str = "grpc", request_type=model_service.UploadModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -376,11 +457,9 @@ def test_upload_model(transport: str = 'grpc', request_type=model_service.Upload request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.upload_model(request) @@ -402,25 +481,24 @@ def test_upload_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: client.upload_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UploadModelRequest() + @pytest.mark.asyncio -async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UploadModelRequest): +async def test_upload_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -428,12 +506,10 @@ async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.upload_model(request) @@ -454,20 +530,16 @@ async def test_upload_model_async_from_dict(): def test_upload_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.upload_model(request) @@ -478,28 +550,23 @@ def test_upload_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.upload_model(request) @@ -510,29 +577,21 @@ async def test_upload_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_upload_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.upload_model( - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", model=gca_model.Model(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -540,47 +599,40 @@ def test_upload_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") def test_upload_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.upload_model( model_service.UploadModelRequest(), - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", + model=gca_model.Model(name="name_value"), ) @pytest.mark.asyncio async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.upload_model( - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", model=gca_model.Model(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -588,31 +640,28 @@ async def test_upload_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") @pytest.mark.asyncio async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.upload_model( model_service.UploadModelRequest(), - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", + model=gca_model.Model(name="name_value"), ) -def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest): +def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -620,31 +669,21 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - training_pipeline='training_pipeline_value', - - artifact_uri='artifact_uri_value', - - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - - supported_input_storage_formats=['supported_input_storage_formats_value'], - - supported_output_storage_formats=['supported_output_storage_formats_value'], - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", ) response = client.get_model(request) @@ -659,25 +698,31 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_model_from_dict(): @@ -688,25 +733,24 @@ def test_get_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelRequest() + @pytest.mark.asyncio -async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -714,22 +758,28 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.get_model(request) @@ -742,25 +792,31 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -769,19 +825,15 @@ async def test_get_model_async_from_dict(): def test_get_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = model.Model() client.get_model(request) @@ -793,27 +845,20 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -825,99 +870,79 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model( - name='name_value', - ) + client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model( - name='name_value', - ) + response = await client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) -def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): +def test_list_models( + transport: str = "grpc", request_type=model_service.ListModelsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -925,13 +950,10 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_models(request) @@ -946,7 +968,7 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_models_from_dict(): @@ -957,25 +979,24 @@ def test_list_models_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelsRequest() + @pytest.mark.asyncio -async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -983,13 +1004,11 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_models(request) @@ -1002,7 +1021,7 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1011,19 +1030,15 @@ async def test_list_models_async_from_dict(): def test_list_models_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: call.return_value = model_service.ListModelsResponse() client.list_models(request) @@ -1035,28 +1050,23 @@ def test_list_models_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) await client.list_models(request) @@ -1067,138 +1077,98 @@ async def test_list_models_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_models_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_models( - parent='parent_value', - ) + client.list_models(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_models_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_models( - model_service.ListModelsRequest(), - parent='parent_value', + model_service.ListModelsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_models( - parent='parent_value', - ) + response = await client.list_models(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_models( - model_service.ListModelsRequest(), - parent='parent_value', + model_service.ListModelsRequest(), parent="parent_value", ) def test_list_models_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_models(request={}) @@ -1206,147 +1176,96 @@ def test_list_models_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model.Model) - for i in results) + assert all(isinstance(i, model.Model) for i in results) + def test_list_models_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_models_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) - for i in responses) + assert all(isinstance(i, model.Model) for i in responses) + @pytest.mark.asyncio async def test_list_models_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[], - next_page_token='def', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_models(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest): +def test_update_model( + transport: str = "grpc", request_type=model_service.UpdateModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1354,31 +1273,21 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - training_pipeline='training_pipeline_value', - - artifact_uri='artifact_uri_value', - - supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - - supported_input_storage_formats=['supported_input_storage_formats_value'], - - supported_output_storage_formats=['supported_output_storage_formats_value'], - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", ) response = client.update_model(request) @@ -1393,25 +1302,31 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update assert isinstance(response, gca_model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_model_from_dict(): @@ -1422,25 +1337,24 @@ def test_update_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: client.update_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateModelRequest() + @pytest.mark.asyncio -async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateModelRequest): +async def test_update_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1448,22 +1362,28 @@ async def test_update_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.update_model(request) @@ -1476,25 +1396,31 @@ async def test_update_model_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -1503,19 +1429,15 @@ async def test_update_model_async_from_dict(): def test_update_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = 'model.name/value' + request.model.name = "model.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: call.return_value = gca_model.Model() client.update_model(request) @@ -1527,27 +1449,20 @@ def test_update_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model.name=model.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = 'model.name/value' + request.model.name = "model.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) await client.update_model(request) @@ -1559,29 +1474,22 @@ async def test_update_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model.name=model.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] def test_update_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_model( - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1589,36 +1497,30 @@ def test_update_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() @@ -1626,8 +1528,8 @@ async def test_update_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model( - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1635,31 +1537,30 @@ async def test_update_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): +def test_delete_model( + transport: str = "grpc", request_type=model_service.DeleteModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1667,11 +1568,9 @@ def test_delete_model(transport: str = 'grpc', request_type=model_service.Delete request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_model(request) @@ -1693,25 +1592,24 @@ def test_delete_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: client.delete_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.DeleteModelRequest() + @pytest.mark.asyncio -async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteModelRequest): +async def test_delete_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1719,12 +1617,10 @@ async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_model(request) @@ -1745,20 +1641,16 @@ async def test_delete_model_async_from_dict(): def test_delete_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_model(request) @@ -1769,28 +1661,23 @@ def test_delete_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_model(request) @@ -1801,101 +1688,81 @@ async def test_delete_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model( - name='name_value', - ) + client.delete_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_model( - model_service.DeleteModelRequest(), - name='name_value', + model_service.DeleteModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_model( - name='name_value', - ) + response = await client.delete_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model( - model_service.DeleteModelRequest(), - name='name_value', + model_service.DeleteModelRequest(), name="name_value", ) -def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): +def test_export_model( + transport: str = "grpc", request_type=model_service.ExportModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1903,11 +1770,9 @@ def test_export_model(transport: str = 'grpc', request_type=model_service.Export request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_model(request) @@ -1929,25 +1794,24 @@ def test_export_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: client.export_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ExportModelRequest() + @pytest.mark.asyncio -async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=model_service.ExportModelRequest): +async def test_export_model_async( + transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1955,12 +1819,10 @@ async def test_export_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_model(request) @@ -1981,20 +1843,16 @@ async def test_export_model_async_from_dict(): def test_export_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_model(request) @@ -2005,28 +1863,23 @@ def test_export_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_model(request) @@ -2037,29 +1890,24 @@ async def test_export_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_model( - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) # Establish that the underlying call was made with the expected @@ -2067,47 +1915,47 @@ def test_export_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) def test_export_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_model( model_service.ExportModelRequest(), - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) -@pytest.mark.asyncio -async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) +@pytest.mark.asyncio +async def test_export_model_flattened_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_model( - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) # Establish that the underlying call was made with the expected @@ -2115,31 +1963,34 @@ async def test_export_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) @pytest.mark.asyncio async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_model( model_service.ExportModelRequest(), - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) -def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): +def test_get_model_evaluation( + transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2148,16 +1999,13 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - - slice_dimensions=['slice_dimensions_value'], - + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], ) response = client.get_model_evaluation(request) @@ -2172,11 +2020,11 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ['slice_dimensions_value'] + assert response.slice_dimensions == ["slice_dimensions_value"] def test_get_model_evaluation_from_dict(): @@ -2187,25 +2035,27 @@ def test_get_model_evaluation_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: client.get_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationRequest() + @pytest.mark.asyncio -async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationRequest): +async def test_get_model_evaluation_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2214,14 +2064,16 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - slice_dimensions=['slice_dimensions_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation( + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], + ) + ) response = await client.get_model_evaluation(request) @@ -2234,11 +2086,11 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', reque # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ['slice_dimensions_value'] + assert response.slice_dimensions == ["slice_dimensions_value"] @pytest.mark.asyncio @@ -2247,19 +2099,17 @@ async def test_get_model_evaluation_async_from_dict(): def test_get_model_evaluation_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: call.return_value = model_evaluation.ModelEvaluation() client.get_model_evaluation(request) @@ -2271,28 +2121,25 @@ def test_get_model_evaluation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + type(client.transport.get_model_evaluation), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) await client.get_model_evaluation(request) @@ -2303,99 +2150,85 @@ async def test_get_model_evaluation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_evaluation_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation( - name='name_value', - ) + client.get_model_evaluation(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), - name='name_value', + model_service.GetModelEvaluationRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_evaluation_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation( - name='name_value', - ) + response = await client.get_model_evaluation(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_evaluation_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), - name='name_value', + model_service.GetModelEvaluationRequest(), name="name_value", ) -def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest): +def test_list_model_evaluations( + transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2404,12 +2237,11 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluations(request) @@ -2424,7 +2256,7 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluations_from_dict(): @@ -2435,25 +2267,27 @@ def test_list_model_evaluations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: client.list_model_evaluations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationsRequest() + @pytest.mark.asyncio -async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationsRequest): +async def test_list_model_evaluations_async( + transport: str = "grpc_asyncio", + request_type=model_service.ListModelEvaluationsRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2462,12 +2296,14 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluations(request) @@ -2480,7 +2316,7 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', req # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2489,19 +2325,17 @@ async def test_list_model_evaluations_async_from_dict(): def test_list_model_evaluations_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationsResponse() client.list_model_evaluations(request) @@ -2513,28 +2347,25 @@ def test_list_model_evaluations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + type(client.transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) await client.list_model_evaluations(request) @@ -2545,104 +2376,87 @@ async def test_list_model_evaluations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluations_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluations( - parent='parent_value', - ) + client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluations( - parent='parent_value', - ) + response = await client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) def test_list_model_evaluations_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2651,17 +2465,14 @@ def test_list_model_evaluations_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2674,9 +2485,7 @@ def test_list_model_evaluations_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluations(request={}) @@ -2684,18 +2493,16 @@ def test_list_model_evaluations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in results) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) + def test_list_model_evaluations_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2704,17 +2511,14 @@ def test_list_model_evaluations_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2725,19 +2529,20 @@ def test_list_model_evaluations_pages(): RuntimeError, ) pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2746,17 +2551,14 @@ async def test_list_model_evaluations_async_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2767,25 +2569,25 @@ async def test_list_model_evaluations_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in responses) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) + @pytest.mark.asyncio async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2794,17 +2596,14 @@ async def test_list_model_evaluations_async_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2817,14 +2616,15 @@ async def test_list_model_evaluations_async_pages(): pages = [] async for page_ in (await client.list_model_evaluations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): +def test_get_model_evaluation_slice( + transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2833,14 +2633,11 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - + name="name_value", metrics_schema_uri="metrics_schema_uri_value", ) response = client.get_model_evaluation_slice(request) @@ -2855,9 +2652,9 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" def test_get_model_evaluation_slice_from_dict(): @@ -2868,25 +2665,27 @@ def test_get_model_evaluation_slice_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: client.get_model_evaluation_slice() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationSliceRequest() + @pytest.mark.asyncio -async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationSliceRequest): +async def test_get_model_evaluation_slice_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationSliceRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2895,13 +2694,14 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + ) response = await client.get_model_evaluation_slice(request) @@ -2914,9 +2714,9 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" @pytest.mark.asyncio @@ -2925,19 +2725,17 @@ async def test_get_model_evaluation_slice_async_from_dict(): def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: call.return_value = model_evaluation_slice.ModelEvaluationSlice() client.get_model_evaluation_slice(request) @@ -2949,28 +2747,25 @@ def test_get_model_evaluation_slice_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) await client.get_model_evaluation_slice(request) @@ -2981,99 +2776,85 @@ async def test_get_model_evaluation_slice_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation_slice( - name='name_value', - ) + client.get_model_evaluation_slice(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), - name='name_value', + model_service.GetModelEvaluationSliceRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation_slice( - name='name_value', - ) + response = await client.get_model_evaluation_slice(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), - name='name_value', + model_service.GetModelEvaluationSliceRequest(), name="name_value", ) -def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest): +def test_list_model_evaluation_slices( + transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3082,12 +2863,11 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluation_slices(request) @@ -3102,7 +2882,7 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluation_slices_from_dict(): @@ -3113,25 +2893,27 @@ def test_list_model_evaluation_slices_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: client.list_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationSlicesRequest() + @pytest.mark.asyncio -async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationSlicesRequest): +async def test_list_model_evaluation_slices_async( + transport: str = "grpc_asyncio", + request_type=model_service.ListModelEvaluationSlicesRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3140,12 +2922,14 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluation_slices(request) @@ -3158,7 +2942,7 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3167,19 +2951,17 @@ async def test_list_model_evaluation_slices_async_from_dict(): def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationSlicesResponse() client.list_model_evaluation_slices(request) @@ -3191,28 +2973,25 @@ def test_list_model_evaluation_slices_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluation_slices_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) await client.list_model_evaluation_slices(request) @@ -3223,104 +3002,87 @@ async def test_list_model_evaluation_slices_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluation_slices( - parent='parent_value', - ) + client.list_model_evaluation_slices(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), - parent='parent_value', + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluation_slices( - parent='parent_value', - ) + response = await client.list_model_evaluation_slices(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), - parent='parent_value', + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", ) def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3329,17 +3091,16 @@ def test_list_model_evaluation_slices_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3352,9 +3113,7 @@ def test_list_model_evaluation_slices_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluation_slices(request={}) @@ -3362,18 +3121,18 @@ def test_list_model_evaluation_slices_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in results) + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results + ) + def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3382,17 +3141,16 @@ def test_list_model_evaluation_slices_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3403,19 +3161,20 @@ def test_list_model_evaluation_slices_pages(): RuntimeError, ) pages = list(client.list_model_evaluation_slices(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluation_slices), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3424,17 +3183,16 @@ async def test_list_model_evaluation_slices_async_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3445,25 +3203,28 @@ async def test_list_model_evaluation_slices_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluation_slices(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in responses) + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in responses + ) + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluation_slices), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3472,17 +3233,16 @@ async def test_list_model_evaluation_slices_async_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3493,9 +3253,11 @@ async def test_list_model_evaluation_slices_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_model_evaluation_slices(request={})).pages: + async for page_ in ( + await client.list_model_evaluation_slices(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3506,8 +3268,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3526,8 +3287,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3555,13 +3315,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3569,13 +3332,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.ModelServiceGrpcTransport, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) def test_model_service_base_transport_error(): @@ -3583,13 +3341,15 @@ def test_model_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_model_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3598,17 +3358,17 @@ def test_model_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'upload_model', - 'get_model', - 'list_models', - 'update_model', - 'delete_model', - 'export_model', - 'get_model_evaluation', - 'list_model_evaluations', - 'get_model_evaluation_slice', - 'list_model_evaluation_slices', - ) + "upload_model", + "get_model", + "list_models", + "update_model", + "delete_model", + "export_model", + "get_model_evaluation", + "list_model_evaluations", + "get_model_evaluation_slice", + "list_model_evaluation_slices", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3621,23 +3381,28 @@ def test_model_service_base_transport(): def test_model_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_model_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport() @@ -3646,11 +3411,11 @@ def test_model_service_base_transport_with_adc(): def test_model_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) ModelServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3658,19 +3423,22 @@ def test_model_service_auth_adc(): def test_model_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3679,15 +3447,13 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3702,38 +3468,40 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_model_service_host_no_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_model_service_host_with_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_model_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3741,12 +3509,11 @@ def test_model_service_grpc_transport_channel(): def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3755,12 +3522,17 @@ def test_model_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3769,7 +3541,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3785,9 +3557,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3801,17 +3571,20 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3828,9 +3601,7 @@ def test_model_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3843,16 +3614,12 @@ def test_model_service_transport_channel_mtls_with_adc( def test_model_service_grpc_lro_client(): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3860,16 +3627,12 @@ def test_model_service_grpc_lro_client(): def test_model_service_grpc_lro_async_client(): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3880,17 +3643,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = ModelServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = ModelServiceClient.endpoint_path(**expected) @@ -3898,22 +3662,24 @@ def test_parse_endpoint_path(): actual = ModelServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = ModelServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = ModelServiceClient.model_path(**expected) @@ -3921,24 +3687,28 @@ def test_parse_model_path(): actual = ModelServiceClient.parse_model_path(path) assert expected == actual + def test_model_evaluation_path(): project = "squid" location = "clam" model = "whelk" evaluation = "octopus" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) - actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) + actual = ModelServiceClient.model_evaluation_path( + project, location, model, evaluation + ) assert expected == actual def test_parse_model_evaluation_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "model": "cuttlefish", - "evaluation": "mussel", - + "project": "oyster", + "location": "nudibranch", + "model": "cuttlefish", + "evaluation": "mussel", } path = ModelServiceClient.model_evaluation_path(**expected) @@ -3946,6 +3716,7 @@ def test_parse_model_evaluation_path(): actual = ModelServiceClient.parse_model_evaluation_path(path) assert expected == actual + def test_model_evaluation_slice_path(): project = "winkle" location = "nautilus" @@ -3953,19 +3724,26 @@ def test_model_evaluation_slice_path(): evaluation = "abalone" slice = "squid" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) - actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) + actual = ModelServiceClient.model_evaluation_slice_path( + project, location, model, evaluation, slice + ) assert expected == actual def test_parse_model_evaluation_slice_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - "evaluation": "oyster", - "slice": "nudibranch", - + "project": "clam", + "location": "whelk", + "model": "octopus", + "evaluation": "oyster", + "slice": "nudibranch", } path = ModelServiceClient.model_evaluation_slice_path(**expected) @@ -3973,22 +3751,26 @@ def test_parse_model_evaluation_slice_path(): actual = ModelServiceClient.parse_model_evaluation_slice_path(path) assert expected == actual + def test_training_pipeline_path(): project = "cuttlefish" location = "mussel" training_pipeline = "winkle" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) - actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = ModelServiceClient.training_pipeline_path( + project, location, training_pipeline + ) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "nautilus", - "location": "scallop", - "training_pipeline": "abalone", - + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", } path = ModelServiceClient.training_pipeline_path(**expected) @@ -3996,18 +3778,20 @@ def test_parse_training_pipeline_path(): actual = ModelServiceClient.parse_training_pipeline_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = ModelServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = ModelServiceClient.common_billing_account_path(**expected) @@ -4015,18 +3799,18 @@ def test_parse_common_billing_account_path(): actual = ModelServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = ModelServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = ModelServiceClient.common_folder_path(**expected) @@ -4034,18 +3818,18 @@ def test_parse_common_folder_path(): actual = ModelServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = ModelServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = ModelServiceClient.common_organization_path(**expected) @@ -4053,18 +3837,18 @@ def test_parse_common_organization_path(): actual = ModelServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = ModelServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = ModelServiceClient.common_project_path(**expected) @@ -4072,20 +3856,22 @@ def test_parse_common_project_path(): actual = ModelServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = ModelServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = ModelServiceClient.common_location_path(**expected) @@ -4097,17 +3883,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = ModelServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index de2ff38ef2..d0079aae4d 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.pipeline_service import PipelineServiceAsyncClient +from google.cloud.aiplatform_v1.services.pipeline_service import ( + PipelineServiceAsyncClient, +) from google.cloud.aiplatform_v1.services.pipeline_service import PipelineServiceClient from google.cloud.aiplatform_v1.services.pipeline_service import pagers from google.cloud.aiplatform_v1.services.pipeline_service import transports @@ -66,7 +68,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -77,36 +83,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PipelineServiceClient._get_default_mtls_endpoint(None) is None - assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - PipelineServiceClient, - PipelineServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] +) def test_pipeline_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - PipelineServiceClient, - PipelineServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] +) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -116,7 +138,7 @@ def test_pipeline_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_pipeline_service_client_get_transport_class(): @@ -130,29 +152,44 @@ def test_pipeline_service_client_get_transport_class(): assert transport == transports.PipelineServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) -@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) -def test_pipeline_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +def test_pipeline_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -168,7 +205,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -184,7 +221,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -204,13 +241,15 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -223,26 +262,62 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - -]) -@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) -@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "true", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "false", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_pipeline_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -265,10 +340,18 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -289,9 +372,14 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -305,16 +393,23 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -327,16 +422,24 @@ def test_pipeline_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -351,10 +454,12 @@ def test_pipeline_service_client_client_options_credentials_file(client_class, t def test_pipeline_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = PipelineServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -367,10 +472,11 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): +def test_create_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -379,18 +485,14 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline( - name='name_value', - - display_name='display_name_value', - - training_task_definition='training_task_definition_value', - + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) response = client.create_training_pipeline(request) @@ -405,11 +507,11 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -422,25 +524,27 @@ def test_create_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: client.create_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CreateTrainingPipelineRequest): +async def test_create_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CreateTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -449,15 +553,17 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( - name='name_value', - display_name='display_name_value', - training_task_definition='training_task_definition_value', - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) response = await client.create_training_pipeline(request) @@ -470,11 +576,11 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -485,19 +591,17 @@ async def test_create_training_pipeline_async_from_dict(): def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: call.return_value = gca_training_pipeline.TrainingPipeline() client.create_training_pipeline(request) @@ -509,28 +613,25 @@ def test_create_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) await client.create_training_pipeline(request) @@ -541,29 +642,24 @@ async def test_create_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_training_pipeline( - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -571,45 +667,45 @@ def test_create_training_pipeline_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_training_pipeline( - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -617,31 +713,32 @@ async def test_create_training_pipeline_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) -def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest): +def test_get_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -650,18 +747,14 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline( - name='name_value', - - display_name='display_name_value', - - training_task_definition='training_task_definition_value', - + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) response = client.get_training_pipeline(request) @@ -676,11 +769,11 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -693,25 +786,27 @@ def test_get_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: client.get_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.GetTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.GetTrainingPipelineRequest): +async def test_get_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.GetTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -720,15 +815,17 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( - name='name_value', - display_name='display_name_value', - training_task_definition='training_task_definition_value', - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) response = await client.get_training_pipeline(request) @@ -741,11 +838,11 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -756,19 +853,17 @@ async def test_get_training_pipeline_async_from_dict(): def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: call.return_value = training_pipeline.TrainingPipeline() client.get_training_pipeline(request) @@ -780,28 +875,25 @@ def test_get_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) + type(client.transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) await client.get_training_pipeline(request) @@ -812,99 +904,85 @@ async def test_get_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_training_pipeline( - name='name_value', - ) + client.get_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), - name='name_value', + pipeline_service.GetTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_training_pipeline( - name='name_value', - ) + response = await client.get_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), - name='name_value', + pipeline_service.GetTrainingPipelineRequest(), name="name_value", ) -def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): +def test_list_training_pipelines( + transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -913,12 +991,11 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_training_pipelines(request) @@ -933,7 +1010,7 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_ assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_training_pipelines_from_dict(): @@ -944,25 +1021,27 @@ def test_list_training_pipelines_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: client.list_training_pipelines() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + @pytest.mark.asyncio -async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.ListTrainingPipelinesRequest): +async def test_list_training_pipelines_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.ListTrainingPipelinesRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -971,12 +1050,14 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_training_pipelines(request) @@ -989,7 +1070,7 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', re # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -998,19 +1079,17 @@ async def test_list_training_pipelines_async_from_dict(): def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: call.return_value = pipeline_service.ListTrainingPipelinesResponse() client.list_training_pipelines(request) @@ -1022,28 +1101,25 @@ def test_list_training_pipelines_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) await client.list_training_pipelines(request) @@ -1054,104 +1130,87 @@ async def test_list_training_pipelines_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_training_pipelines_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_training_pipelines( - parent='parent_value', - ) + client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_training_pipelines_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_training_pipelines_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_training_pipelines( - parent='parent_value', - ) + response = await client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) def test_list_training_pipelines_pager(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1160,17 +1219,14 @@ def test_list_training_pipelines_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1183,9 +1239,7 @@ def test_list_training_pipelines_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_training_pipelines(request={}) @@ -1193,18 +1247,16 @@ def test_list_training_pipelines_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in results) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) + def test_list_training_pipelines_pages(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1213,17 +1265,14 @@ def test_list_training_pipelines_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1234,19 +1283,20 @@ def test_list_training_pipelines_pages(): RuntimeError, ) pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1255,17 +1305,14 @@ async def test_list_training_pipelines_async_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1276,25 +1323,25 @@ async def test_list_training_pipelines_async_pager(): RuntimeError, ) async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in responses) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) + @pytest.mark.asyncio async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1303,17 +1350,14 @@ async def test_list_training_pipelines_async_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1326,14 +1370,15 @@ async def test_list_training_pipelines_async_pages(): pages = [] async for page_ in (await client.list_training_pipelines(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest): +def test_delete_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1342,10 +1387,10 @@ def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_training_pipeline(request) @@ -1367,25 +1412,27 @@ def test_delete_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: client.delete_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.DeleteTrainingPipelineRequest): +async def test_delete_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.DeleteTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1394,11 +1441,11 @@ async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_training_pipeline(request) @@ -1419,20 +1466,18 @@ async def test_delete_training_pipeline_async_from_dict(): def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_training_pipeline(request) @@ -1443,28 +1488,25 @@ def test_delete_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_training_pipeline(request) @@ -1475,101 +1517,85 @@ async def test_delete_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_training_pipeline( - name='name_value', - ) + client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_training_pipeline( - name='name_value', - ) + response = await client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) -def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest): +def test_cancel_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1578,8 +1604,8 @@ def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1603,25 +1629,27 @@ def test_cancel_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: client.cancel_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CancelTrainingPipelineRequest): +async def test_cancel_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CancelTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1630,8 +1658,8 @@ async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1653,19 +1681,17 @@ async def test_cancel_training_pipeline_async_from_dict(): def test_cancel_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = None client.cancel_training_pipeline(request) @@ -1677,27 +1703,22 @@ def test_cancel_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_training_pipeline(request) @@ -1709,92 +1730,75 @@ async def test_cancel_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_training_pipeline( - name='name_value', - ) + client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_training_pipeline( - name='name_value', - ) + response = await client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @@ -1805,8 +1809,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1825,8 +1828,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1854,13 +1856,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1868,13 +1873,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.PipelineServiceGrpcTransport, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) def test_pipeline_service_base_transport_error(): @@ -1882,13 +1882,15 @@ def test_pipeline_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_pipeline_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1897,12 +1899,12 @@ def test_pipeline_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_training_pipeline', - 'get_training_pipeline', - 'list_training_pipelines', - 'delete_training_pipeline', - 'cancel_training_pipeline', - ) + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1915,23 +1917,28 @@ def test_pipeline_service_base_transport(): def test_pipeline_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_pipeline_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport() @@ -1940,11 +1947,11 @@ def test_pipeline_service_base_transport_with_adc(): def test_pipeline_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PipelineServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1952,19 +1959,25 @@ def test_pipeline_service_auth_adc(): def test_pipeline_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) -def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1973,15 +1986,13 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1996,38 +2007,40 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_pipeline_service_host_no_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_pipeline_service_host_with_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_pipeline_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2035,12 +2048,11 @@ def test_pipeline_service_grpc_transport_channel(): def test_pipeline_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2049,12 +2061,22 @@ def test_pipeline_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_pipeline_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2063,7 +2085,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2079,9 +2101,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2095,17 +2115,23 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) -def test_pipeline_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2122,9 +2148,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2137,16 +2161,12 @@ def test_pipeline_service_transport_channel_mtls_with_adc( def test_pipeline_service_grpc_lro_client(): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2154,16 +2174,12 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2174,17 +2190,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = PipelineServiceClient.endpoint_path(**expected) @@ -2192,22 +2209,24 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = PipelineServiceClient.model_path(**expected) @@ -2215,22 +2234,26 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual + def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) - actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = PipelineServiceClient.training_pipeline_path( + project, location, training_pipeline + ) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", - + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2238,18 +2261,20 @@ def test_parse_training_pipeline_path(): actual = PipelineServiceClient.parse_training_pipeline_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = PipelineServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2257,18 +2282,18 @@ def test_parse_common_billing_account_path(): actual = PipelineServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = PipelineServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = PipelineServiceClient.common_folder_path(**expected) @@ -2276,18 +2301,18 @@ def test_parse_common_folder_path(): actual = PipelineServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = PipelineServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = PipelineServiceClient.common_organization_path(**expected) @@ -2295,18 +2320,18 @@ def test_parse_common_organization_path(): actual = PipelineServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = PipelineServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = PipelineServiceClient.common_project_path(**expected) @@ -2314,20 +2339,22 @@ def test_parse_common_project_path(): actual = PipelineServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = PipelineServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = PipelineServiceClient.common_location_path(**expected) @@ -2339,17 +2366,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = PipelineServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index 4017a16cc3..339187f22a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient -from google.cloud.aiplatform_v1.services.specialist_pool_service import SpecialistPoolServiceClient +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + SpecialistPoolServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1.services.specialist_pool_service import transports from google.cloud.aiplatform_v1.types import operation as gca_operation @@ -56,7 +60,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -67,36 +75,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - SpecialistPoolServiceClient, - SpecialistPoolServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] +) def test_specialist_pool_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - SpecialistPoolServiceClient, - SpecialistPoolServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] +) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -106,7 +131,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_client_get_transport_class(): @@ -120,29 +145,48 @@ def test_specialist_pool_service_client_get_transport_class(): assert transport == transports.SpecialistPoolServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) -def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -158,7 +202,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -174,7 +218,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -194,13 +238,15 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -213,26 +259,62 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "true", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "false", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_specialist_pool_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -255,10 +337,18 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -279,9 +369,14 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -295,16 +390,27 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -317,16 +423,28 @@ def test_specialist_pool_service_client_client_options_scopes(client_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -341,10 +459,12 @@ def test_specialist_pool_service_client_client_options_credentials_file(client_c def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -357,10 +477,12 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +def test_create_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -369,10 +491,10 @@ def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_specialist_pool(request) @@ -394,25 +516,27 @@ def test_create_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: client.create_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +async def test_create_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -421,11 +545,11 @@ async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_specialist_pool(request) @@ -453,13 +577,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_specialist_pool(request) @@ -470,10 +594,7 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -485,13 +606,15 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_specialist_pool(request) @@ -502,10 +625,7 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_specialist_pool_flattened(): @@ -515,16 +635,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -532,9 +652,11 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) def test_create_specialist_pool_flattened_error(): @@ -547,8 +669,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) @@ -560,19 +682,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -580,9 +702,11 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) @pytest.mark.asyncio @@ -596,15 +720,17 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) -def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): +def test_get_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -613,20 +739,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", specialist_managers_count=2662, - - specialist_manager_emails=['specialist_manager_emails_value'], - - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], ) response = client.get_specialist_pool(request) @@ -641,15 +762,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] def test_get_specialist_pool_from_dict(): @@ -660,25 +781,27 @@ def test_get_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: client.get_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.GetSpecialistPoolRequest): +async def test_get_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -687,16 +810,18 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( - name='name_value', - display_name='display_name_value', - specialist_managers_count=2662, - specialist_manager_emails=['specialist_manager_emails_value'], - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + ) response = await client.get_specialist_pool(request) @@ -709,15 +834,15 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] @pytest.mark.asyncio @@ -733,12 +858,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -750,10 +875,7 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -765,13 +887,15 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) await client.get_specialist_pool(request) @@ -782,10 +906,7 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_specialist_pool_flattened(): @@ -795,23 +916,21 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_specialist_pool( - name='name_value', - ) + client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_specialist_pool_flattened_error(): @@ -823,8 +942,7 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) @@ -836,24 +954,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool( - name='name_value', - ) + response = await client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -866,15 +984,16 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) -def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +def test_list_specialist_pools( + transport: str = "grpc", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -883,12 +1002,11 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_specialist_pools(request) @@ -903,7 +1021,7 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_specialist_pools_from_dict(): @@ -914,25 +1032,27 @@ def test_list_specialist_pools_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: client.list_specialist_pools() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + @pytest.mark.asyncio -async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +async def test_list_specialist_pools_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -941,12 +1061,14 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_specialist_pools(request) @@ -959,7 +1081,7 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -975,12 +1097,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -992,10 +1114,7 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1007,13 +1126,15 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + type(client.transport.list_specialist_pools), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) await client.list_specialist_pools(request) @@ -1024,10 +1145,7 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_specialist_pools_flattened(): @@ -1037,23 +1155,21 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_specialist_pools( - parent='parent_value', - ) + client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_specialist_pools_flattened_error(): @@ -1065,8 +1181,7 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) @@ -1078,24 +1193,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools( - parent='parent_value', - ) + response = await client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -1108,20 +1223,17 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1130,17 +1242,14 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1153,9 +1262,7 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_specialist_pools(request={}) @@ -1163,18 +1270,16 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in results) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) + def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1183,17 +1288,14 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1204,9 +1306,10 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1215,8 +1318,10 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1225,17 +1330,14 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1246,14 +1348,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in responses) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) + @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1263,8 +1365,10 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1273,17 +1377,14 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1296,14 +1397,16 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +def test_delete_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1312,10 +1415,10 @@ def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_specialist_pool(request) @@ -1337,25 +1440,27 @@ def test_delete_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: client.delete_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +async def test_delete_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1364,11 +1469,11 @@ async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_specialist_pool(request) @@ -1396,13 +1501,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_specialist_pool(request) @@ -1413,10 +1518,7 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1428,13 +1530,15 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_specialist_pool(request) @@ -1445,10 +1549,7 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_specialist_pool_flattened(): @@ -1458,23 +1559,21 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool( - name='name_value', - ) + client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_specialist_pool_flattened_error(): @@ -1486,8 +1585,7 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) @@ -1499,26 +1597,24 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool( - name='name_value', - ) + response = await client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -1531,15 +1627,16 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) -def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +def test_update_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1548,10 +1645,10 @@ def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_specialist_pool(request) @@ -1573,25 +1670,27 @@ def test_update_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: client.update_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +async def test_update_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1600,11 +1699,11 @@ async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_specialist_pool(request) @@ -1632,13 +1731,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_specialist_pool(request) @@ -1650,9 +1749,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1664,13 +1763,15 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_specialist_pool(request) @@ -1682,9 +1783,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] def test_update_specialist_pool_flattened(): @@ -1694,16 +1795,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1711,9 +1812,11 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_specialist_pool_flattened_error(): @@ -1726,8 +1829,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1739,19 +1842,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1759,9 +1862,11 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -1775,8 +1880,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1787,8 +1892,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1807,8 +1911,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1836,13 +1939,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1853,10 +1959,7 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance( - client.transport, - transports.SpecialistPoolServiceGrpcTransport, - ) + assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) def test_specialist_pool_service_base_transport_error(): @@ -1864,13 +1967,15 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1879,12 +1984,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_specialist_pool', - 'get_specialist_pool', - 'list_specialist_pools', - 'delete_specialist_pool', - 'update_specialist_pool', - ) + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1897,23 +2002,28 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -1922,11 +2032,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1934,18 +2044,26 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( - transport_class + transport_class, ): cred = credentials.AnonymousCredentials() @@ -1955,15 +2073,13 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1978,38 +2094,40 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2017,12 +2135,11 @@ def test_specialist_pool_service_grpc_transport_channel(): def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2031,12 +2148,22 @@ def test_specialist_pool_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2045,7 +2172,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2061,9 +2188,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2077,17 +2202,23 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) -def test_specialist_pool_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2104,9 +2235,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2119,16 +2248,12 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2136,16 +2261,12 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2156,17 +2277,20 @@ def test_specialist_pool_path(): location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) - actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + actual = SpecialistPoolServiceClient.specialist_pool_path( + project, location, specialist_pool + ) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", - + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -2174,18 +2298,20 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -2193,18 +2319,18 @@ def test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2212,18 +2338,18 @@ def test_parse_common_folder_path(): actual = SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2231,18 +2357,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = SpecialistPoolServiceClient.common_project_path(**expected) @@ -2250,20 +2376,22 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2275,17 +2403,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/__init__.py b/tests/unit/gapic/aiplatform_v1beta1/__init__.py index 6a73015364..42ffdf2bc4 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/__init__.py +++ b/tests/unit/gapic/aiplatform_v1beta1/__init__.py @@ -1,4 +1,3 @@ - # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index eb48bd6ebb..5a3818dc9d 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceClient, +) from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.services.dataset_service import transports from google.cloud.aiplatform_v1beta1.types import annotation @@ -63,7 +67,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -74,36 +82,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - DatasetServiceClient, - DatasetServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] +) def test_dataset_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - DatasetServiceClient, - DatasetServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] +) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -113,7 +137,7 @@ def test_dataset_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_client_get_transport_class(): @@ -127,29 +151,44 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) -def test_dataset_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -165,7 +204,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -181,7 +220,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -201,13 +240,15 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -220,26 +261,52 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_dataset_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -262,10 +329,18 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -286,9 +361,14 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -302,16 +382,23 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -324,16 +411,24 @@ def test_dataset_service_client_client_options_scopes(client_class, transport_cl client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -348,10 +443,12 @@ def test_dataset_service_client_client_options_credentials_file(client_class, tr def test_dataset_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -364,10 +461,11 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): +def test_create_dataset( + transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -375,11 +473,9 @@ def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.Cr request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_dataset(request) @@ -401,25 +497,24 @@ def test_create_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: client.create_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.CreateDatasetRequest() + @pytest.mark.asyncio -async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.CreateDatasetRequest): +async def test_create_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -427,12 +522,10 @@ async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_dataset(request) @@ -453,20 +546,16 @@ async def test_create_dataset_async_from_dict(): def test_create_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_dataset(request) @@ -477,28 +566,23 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_dataset(request) @@ -509,29 +593,21 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -539,47 +615,40 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") def test_create_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -587,31 +656,30 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) -def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): +def test_get_dataset( + transport: str = "grpc", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -619,19 +687,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.get_dataset(request) @@ -646,13 +708,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_dataset_from_dict(): @@ -663,25 +725,24 @@ def test_get_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: client.get_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetDatasetRequest() + @pytest.mark.asyncio -async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetDatasetRequest): +async def test_get_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -689,16 +750,16 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.get_dataset(request) @@ -711,13 +772,13 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=d # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -726,19 +787,15 @@ async def test_get_dataset_async_from_dict(): def test_get_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -750,27 +807,20 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -782,99 +832,79 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset( - name='name_value', - ) + client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset( - name='name_value', - ) + response = await client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) -def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): +def test_update_dataset( + transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -882,19 +912,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.update_dataset(request) @@ -909,13 +933,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_dataset_from_dict(): @@ -926,25 +950,24 @@ def test_update_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: client.update_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.UpdateDatasetRequest() + @pytest.mark.asyncio -async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.UpdateDatasetRequest): +async def test_update_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -952,16 +975,16 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.update_dataset(request) @@ -974,13 +997,13 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -989,19 +1012,15 @@ async def test_update_dataset_async_from_dict(): def test_update_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -1013,27 +1032,22 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -1045,29 +1059,24 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] def test_update_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1075,36 +1084,30 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1112,8 +1115,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1121,31 +1124,30 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): +def test_list_datasets( + transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1153,13 +1155,10 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_datasets(request) @@ -1174,7 +1173,7 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_datasets_from_dict(): @@ -1185,25 +1184,24 @@ def test_list_datasets_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: client.list_datasets() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDatasetsRequest() + @pytest.mark.asyncio -async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDatasetsRequest): +async def test_list_datasets_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1211,13 +1209,13 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_datasets(request) @@ -1230,7 +1228,7 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1239,19 +1237,15 @@ async def test_list_datasets_async_from_dict(): def test_list_datasets_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1263,28 +1257,23 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) await client.list_datasets(request) @@ -1295,138 +1284,100 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_datasets_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets( - parent='parent_value', - ) + client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_datasets_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets( - parent='parent_value', - ) + response = await client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) def test_list_datasets_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_datasets(request={}) @@ -1434,147 +1385,102 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in results) + assert all(isinstance(i, dataset.Dataset) for i in results) + def test_list_datasets_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in responses) + assert all(isinstance(i, dataset.Dataset) for i in responses) + @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): +def test_delete_dataset( + transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1582,11 +1488,9 @@ def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.De request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_dataset(request) @@ -1608,25 +1512,24 @@ def test_delete_dataset_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: client.delete_dataset() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.DeleteDatasetRequest() + @pytest.mark.asyncio -async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.DeleteDatasetRequest): +async def test_delete_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1634,12 +1537,10 @@ async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_dataset(request) @@ -1660,20 +1561,16 @@ async def test_delete_dataset_async_from_dict(): def test_delete_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_dataset(request) @@ -1684,28 +1581,23 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_dataset(request) @@ -1716,101 +1608,81 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset( - name='name_value', - ) + client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset( - name='name_value', - ) + response = await client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) -def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): +def test_import_data( + transport: str = "grpc", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1818,11 +1690,9 @@ def test_import_data(transport: str = 'grpc', request_type=dataset_service.Impor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.import_data(request) @@ -1844,25 +1714,24 @@ def test_import_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: client.import_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ImportDataRequest() + @pytest.mark.asyncio -async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ImportDataRequest): +async def test_import_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1870,12 +1739,10 @@ async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.import_data(request) @@ -1896,20 +1763,16 @@ async def test_import_data_async_from_dict(): def test_import_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.import_data(request) @@ -1920,28 +1783,23 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.import_data(request) @@ -1952,29 +1810,24 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_import_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -1982,47 +1835,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] def test_import_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -2030,31 +1883,34 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) -def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): +def test_export_data( + transport: str = "grpc", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2062,11 +1918,9 @@ def test_export_data(transport: str = 'grpc', request_type=dataset_service.Expor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_data(request) @@ -2088,25 +1942,24 @@ def test_export_data_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: client.export_data() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ExportDataRequest() + @pytest.mark.asyncio -async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ExportDataRequest): +async def test_export_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2114,12 +1967,10 @@ async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_data(request) @@ -2140,20 +1991,16 @@ async def test_export_data_async_from_dict(): def test_export_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_data(request) @@ -2164,28 +2011,23 @@ def test_export_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_data(request) @@ -2196,29 +2038,26 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2226,47 +2065,53 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) def test_export_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2274,31 +2119,38 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) -def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): +def test_list_data_items( + transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2306,13 +2158,10 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_items(request) @@ -2327,7 +2176,7 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_items_from_dict(): @@ -2338,25 +2187,24 @@ def test_list_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: client.list_data_items() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListDataItemsRequest() + @pytest.mark.asyncio -async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDataItemsRequest): +async def test_list_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2364,13 +2212,13 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_items(request) @@ -2383,7 +2231,7 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2392,19 +2240,15 @@ async def test_list_data_items_async_from_dict(): def test_list_data_items_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2416,28 +2260,23 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) await client.list_data_items(request) @@ -2448,104 +2287,81 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_items_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items( - parent='parent_value', - ) + client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_items_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items( - parent='parent_value', - ) + response = await client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) def test_list_data_items_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2554,32 +2370,23 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_items(request={}) @@ -2587,18 +2394,14 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in results) + assert all(isinstance(i, data_item.DataItem) for i in results) + def test_list_data_items_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2607,40 +2410,32 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2649,46 +2444,37 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in responses) + assert all(isinstance(i, data_item.DataItem) for i in responses) + @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2697,37 +2483,31 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): +def test_get_annotation_spec( + transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2736,16 +2516,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - + name="name_value", display_name="display_name_value", etag="etag_value", ) response = client.get_annotation_spec(request) @@ -2760,11 +2535,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_annotation_spec_from_dict(): @@ -2775,25 +2550,27 @@ def test_get_annotation_spec_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: client.get_annotation_spec() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.GetAnnotationSpecRequest() + @pytest.mark.asyncio -async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetAnnotationSpecRequest): +async def test_get_annotation_spec_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetAnnotationSpecRequest, +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2802,14 +2579,14 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( - name='name_value', - display_name='display_name_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + ) response = await client.get_annotation_spec(request) @@ -2822,11 +2599,11 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -2835,19 +2612,17 @@ async def test_get_annotation_spec_async_from_dict(): def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2859,28 +2634,25 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) await client.get_annotation_spec(request) @@ -2891,99 +2663,85 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec( - name='name_value', - ) + client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec( - name='name_value', - ) + response = await client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) -def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): +def test_list_annotations( + transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2991,13 +2749,10 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_annotations(request) @@ -3012,7 +2767,7 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_annotations_from_dict(): @@ -3023,25 +2778,24 @@ def test_list_annotations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: client.list_annotations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.ListAnnotationsRequest() + @pytest.mark.asyncio -async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListAnnotationsRequest): +async def test_list_annotations_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3049,13 +2803,13 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_annotations(request) @@ -3068,7 +2822,7 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3077,19 +2831,15 @@ async def test_list_annotations_async_from_dict(): def test_list_annotations_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -3101,28 +2851,23 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) await client.list_annotations(request) @@ -3133,104 +2878,81 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_annotations_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations( - parent='parent_value', - ) + client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_annotations_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations( - parent='parent_value', - ) + response = await client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) def test_list_annotations_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3239,32 +2961,23 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_annotations(request={}) @@ -3272,18 +2985,14 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in results) + assert all(isinstance(i, annotation.Annotation) for i in results) + def test_list_annotations_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3292,40 +3001,32 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3334,46 +3035,37 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in responses) + assert all(isinstance(i, annotation.Annotation) for i in responses) + @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3382,30 +3074,23 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3416,8 +3101,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3436,8 +3120,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3465,13 +3148,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3479,13 +3165,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.DatasetServiceGrpcTransport, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) def test_dataset_service_base_transport_error(): @@ -3493,13 +3174,15 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_dataset_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3508,17 +3191,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_dataset', - 'get_dataset', - 'update_dataset', - 'list_datasets', - 'delete_dataset', - 'import_data', - 'export_data', - 'list_data_items', - 'get_annotation_spec', - 'list_annotations', - ) + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3531,23 +3214,28 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -3556,11 +3244,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3568,19 +3256,25 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.DatasetServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) -def test_dataset_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3589,15 +3283,13 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3612,38 +3304,40 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_dataset_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3651,12 +3345,11 @@ def test_dataset_service_grpc_transport_channel(): def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3665,12 +3358,22 @@ def test_dataset_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3679,7 +3382,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3695,9 +3398,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3711,17 +3412,23 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) -def test_dataset_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3738,9 +3445,7 @@ def test_dataset_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3753,16 +3458,12 @@ def test_dataset_service_transport_channel_mtls_with_adc( def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3770,16 +3471,12 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3792,19 +3489,26 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) - actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) + actual = DatasetServiceClient.annotation_path( + project, location, dataset, data_item, annotation + ) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", } path = DatasetServiceClient.annotation_path(**expected) @@ -3812,24 +3516,31 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual + def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) - actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + actual = DatasetServiceClient.annotation_spec_path( + project, location, dataset, annotation_spec + ) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", - + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + "annotation_spec": "nudibranch", } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3837,24 +3548,26 @@ def test_parse_annotation_spec_path(): actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual + def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", - + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", } path = DatasetServiceClient.data_item_path(**expected) @@ -3862,22 +3575,24 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual + def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", } path = DatasetServiceClient.dataset_path(**expected) @@ -3885,18 +3600,20 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3904,18 +3621,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = DatasetServiceClient.common_folder_path(**expected) @@ -3923,18 +3640,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = DatasetServiceClient.common_organization_path(**expected) @@ -3942,18 +3659,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = DatasetServiceClient.common_project_path(**expected) @@ -3961,20 +3678,22 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = DatasetServiceClient.common_location_path(**expected) @@ -3986,17 +3705,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 47d80619c5..a8ee297c20 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceClient +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceClient, +) from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type @@ -63,7 +67,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -74,36 +82,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - EndpointServiceClient, - EndpointServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] +) def test_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - EndpointServiceClient, - EndpointServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] +) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -113,7 +137,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_client_get_transport_class(): @@ -127,29 +151,44 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) -def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -165,7 +204,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -181,7 +220,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -201,13 +240,15 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -220,26 +261,62 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -262,10 +339,18 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -286,9 +371,14 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -302,16 +392,23 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -324,16 +421,24 @@ def test_endpoint_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -348,10 +453,12 @@ def test_endpoint_service_client_client_options_credentials_file(client_class, t def test_endpoint_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -364,10 +471,11 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): +def test_create_endpoint( + transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -375,11 +483,9 @@ def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_endpoint(request) @@ -401,25 +507,24 @@ def test_create_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: client.create_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.CreateEndpointRequest() + @pytest.mark.asyncio -async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.CreateEndpointRequest): +async def test_create_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -427,12 +532,10 @@ async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_endpoint(request) @@ -453,20 +556,16 @@ async def test_create_endpoint_async_from_dict(): def test_create_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_endpoint(request) @@ -477,28 +576,23 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_endpoint(request) @@ -509,29 +603,21 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -539,47 +625,40 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") def test_create_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -587,31 +666,30 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) -def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): +def test_get_endpoint( + transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -619,19 +697,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.get_endpoint(request) @@ -646,13 +718,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_endpoint_from_dict(): @@ -663,25 +735,24 @@ def test_get_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: client.get_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.GetEndpointRequest() + @pytest.mark.asyncio -async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.GetEndpointRequest): +async def test_get_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -689,16 +760,16 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.get_endpoint(request) @@ -711,13 +782,13 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -726,19 +797,15 @@ async def test_get_endpoint_async_from_dict(): def test_get_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -750,27 +817,20 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -782,99 +842,79 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint( - name='name_value', - ) + client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_endpoint( - name='name_value', - ) + response = await client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) -def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): +def test_list_endpoints( + transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -882,13 +922,10 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_endpoints(request) @@ -903,7 +940,7 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_endpoints_from_dict(): @@ -914,25 +951,24 @@ def test_list_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: client.list_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.ListEndpointsRequest() + @pytest.mark.asyncio -async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.ListEndpointsRequest): +async def test_list_endpoints_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -940,13 +976,13 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_endpoints(request) @@ -959,7 +995,7 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -968,19 +1004,15 @@ async def test_list_endpoints_async_from_dict(): def test_list_endpoints_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -992,28 +1024,23 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) await client.list_endpoints(request) @@ -1024,104 +1051,81 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_endpoints_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints( - parent='parent_value', - ) + client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_endpoints_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints( - parent='parent_value', - ) + response = await client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) def test_list_endpoints_pager(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1130,32 +1134,23 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_endpoints(request={}) @@ -1163,18 +1158,14 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in results) + assert all(isinstance(i, endpoint.Endpoint) for i in results) + def test_list_endpoints_pages(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1183,40 +1174,32 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1225,46 +1208,37 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in responses) + assert all(isinstance(i, endpoint.Endpoint) for i in responses) + @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1273,37 +1247,31 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): +def test_update_endpoint( + transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1311,19 +1279,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.update_endpoint(request) @@ -1338,13 +1300,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_endpoint_from_dict(): @@ -1355,25 +1317,24 @@ def test_update_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: client.update_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UpdateEndpointRequest() + @pytest.mark.asyncio -async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UpdateEndpointRequest): +async def test_update_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1381,16 +1342,16 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.update_endpoint(request) @@ -1403,13 +1364,13 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -1418,19 +1379,15 @@ async def test_update_endpoint_async_from_dict(): def test_update_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1442,28 +1399,25 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) await client.update_endpoint(request) @@ -1474,29 +1428,24 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] def test_update_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1504,45 +1453,41 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1550,31 +1495,30 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): +def test_delete_endpoint( + transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1582,11 +1526,9 @@ def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_endpoint(request) @@ -1608,25 +1550,24 @@ def test_delete_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: client.delete_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeleteEndpointRequest() + @pytest.mark.asyncio -async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeleteEndpointRequest): +async def test_delete_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1634,12 +1575,10 @@ async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_endpoint(request) @@ -1660,20 +1599,16 @@ async def test_delete_endpoint_async_from_dict(): def test_delete_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_endpoint(request) @@ -1684,28 +1619,23 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_endpoint(request) @@ -1716,101 +1646,81 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_endpoint( - name='name_value', - ) + client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint( - name='name_value', - ) + response = await client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) -def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): +def test_deploy_model( + transport: str = "grpc", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1818,11 +1728,9 @@ def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.Dep request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.deploy_model(request) @@ -1844,25 +1752,24 @@ def test_deploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: client.deploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.DeployModelRequest() + @pytest.mark.asyncio -async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeployModelRequest): +async def test_deploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1870,12 +1777,10 @@ async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.deploy_model(request) @@ -1896,20 +1801,16 @@ async def test_deploy_model_async_from_dict(): def test_deploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.deploy_model(request) @@ -1920,28 +1821,23 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.deploy_model(request) @@ -1952,30 +1848,29 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_deploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -1983,51 +1878,63 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2035,34 +1942,45 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) -def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): +def test_undeploy_model( + transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2070,11 +1988,9 @@ def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.U request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.undeploy_model(request) @@ -2096,25 +2012,24 @@ def test_undeploy_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: client.undeploy_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UndeployModelRequest() + @pytest.mark.asyncio -async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UndeployModelRequest): +async def test_undeploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2122,12 +2037,10 @@ async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.undeploy_model(request) @@ -2148,20 +2061,16 @@ async def test_undeploy_model_async_from_dict(): def test_undeploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.undeploy_model(request) @@ -2172,28 +2081,23 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.undeploy_model(request) @@ -2204,30 +2108,23 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_undeploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2235,51 +2132,45 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2287,27 +2178,25 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @@ -2318,8 +2207,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2338,8 +2226,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -2367,13 +2254,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2381,13 +2271,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.EndpointServiceGrpcTransport, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) def test_endpoint_service_base_transport_error(): @@ -2395,13 +2280,15 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2410,14 +2297,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_endpoint', - 'get_endpoint', - 'list_endpoints', - 'update_endpoint', - 'delete_endpoint', - 'deploy_model', - 'undeploy_model', - ) + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2430,23 +2317,28 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2455,11 +2347,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -2467,19 +2359,25 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.EndpointServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) -def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -2488,15 +2386,13 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2511,38 +2407,40 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_endpoint_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2550,12 +2448,11 @@ def test_endpoint_service_grpc_transport_channel(): def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2564,12 +2461,22 @@ def test_endpoint_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2578,7 +2485,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2594,9 +2501,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2610,17 +2515,23 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) -def test_endpoint_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2637,9 +2548,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2652,16 +2561,12 @@ def test_endpoint_service_transport_channel_mtls_with_adc( def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2669,16 +2574,12 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2689,17 +2590,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = EndpointServiceClient.endpoint_path(**expected) @@ -2707,22 +2609,24 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = EndpointServiceClient.model_path(**expected) @@ -2730,18 +2634,20 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2749,18 +2655,18 @@ def test_parse_common_billing_account_path(): actual = EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = EndpointServiceClient.common_folder_path(**expected) @@ -2768,18 +2674,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = EndpointServiceClient.common_organization_path(**expected) @@ -2787,18 +2693,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = EndpointServiceClient.common_project_path(**expected) @@ -2806,20 +2712,22 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = EndpointServiceClient.common_location_path(**expected) @@ -2831,17 +2739,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 7593ba87a6..c8e506d54b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -41,24 +41,32 @@ from google.cloud.aiplatform_v1beta1.services.job_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import explanation_metadata from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study @@ -81,7 +89,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -92,36 +104,45 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [ - JobServiceClient, - JobServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - JobServiceClient, - JobServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -131,7 +152,7 @@ def test_job_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_client_get_transport_class(): @@ -145,29 +166,42 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) -def test_job_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -183,7 +217,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -199,7 +233,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -219,13 +253,15 @@ def test_job_service_client_client_options(client_class, transport_class, transp client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -238,26 +274,50 @@ def test_job_service_client_client_options(client_class, transport_class, transp client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_job_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -280,10 +340,18 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -304,9 +372,14 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -320,16 +393,23 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -342,16 +422,24 @@ def test_job_service_client_client_options_scopes(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -366,11 +454,11 @@ def test_job_service_client_client_options_credentials_file(client_class, transp def test_job_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -382,10 +470,11 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): +def test_create_custom_job( + transport: str = "grpc", request_type=job_service.CreateCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -394,16 +483,13 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_custom_job(request) @@ -418,9 +504,9 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -433,25 +519,26 @@ def test_create_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: client.create_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateCustomJobRequest() + @pytest.mark.asyncio -async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateCustomJobRequest): +async def test_create_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -460,14 +547,16 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_custom_job(request) @@ -480,9 +569,9 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -493,19 +582,17 @@ async def test_create_custom_job_async_from_dict(): def test_create_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -517,28 +604,25 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) await client.create_custom_job(request) @@ -549,29 +633,24 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -579,45 +658,43 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") def test_create_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -625,31 +702,30 @@ async def test_create_custom_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") @pytest.mark.asyncio async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) -def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): +def test_get_custom_job( + transport: str = "grpc", request_type=job_service.GetCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -657,17 +733,12 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_custom_job(request) @@ -682,9 +753,9 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus assert isinstance(response, custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -697,25 +768,24 @@ def test_get_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: client.get_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetCustomJobRequest() + @pytest.mark.asyncio -async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetCustomJobRequest): +async def test_get_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -723,15 +793,15 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_custom_job(request) @@ -744,9 +814,9 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -757,19 +827,15 @@ async def test_get_custom_job_async_from_dict(): def test_get_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: call.return_value = custom_job.CustomJob() client.get_custom_job(request) @@ -781,28 +847,23 @@ def test_get_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) await client.get_custom_job(request) @@ -813,99 +874,81 @@ async def test_get_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_custom_job( - name='name_value', - ) + client.get_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_custom_job( - job_service.GetCustomJobRequest(), - name='name_value', + job_service.GetCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_custom_job( - name='name_value', - ) + response = await client.get_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_custom_job( - job_service.GetCustomJobRequest(), - name='name_value', + job_service.GetCustomJobRequest(), name="name_value", ) -def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): +def test_list_custom_jobs( + transport: str = "grpc", request_type=job_service.ListCustomJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -913,13 +956,10 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_custom_jobs(request) @@ -934,7 +974,7 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_custom_jobs_from_dict(): @@ -945,25 +985,24 @@ def test_list_custom_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: client.list_custom_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListCustomJobsRequest() + @pytest.mark.asyncio -async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListCustomJobsRequest): +async def test_list_custom_jobs_async( + transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -971,13 +1010,11 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_custom_jobs(request) @@ -990,7 +1027,7 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListCustomJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -999,19 +1036,15 @@ async def test_list_custom_jobs_async_from_dict(): def test_list_custom_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: call.return_value = job_service.ListCustomJobsResponse() client.list_custom_jobs(request) @@ -1023,28 +1056,23 @@ def test_list_custom_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) await client.list_custom_jobs(request) @@ -1055,104 +1083,81 @@ async def test_list_custom_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_custom_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_custom_jobs( - parent='parent_value', - ) + client.list_custom_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_custom_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_custom_jobs( - job_service.ListCustomJobsRequest(), - parent='parent_value', + job_service.ListCustomJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_custom_jobs( - parent='parent_value', - ) + response = await client.list_custom_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), - parent='parent_value', + job_service.ListCustomJobsRequest(), parent="parent_value", ) def test_list_custom_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1161,32 +1166,21 @@ def test_list_custom_jobs_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_custom_jobs(request={}) @@ -1194,18 +1188,14 @@ def test_list_custom_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) - for i in results) + assert all(isinstance(i, custom_job.CustomJob) for i in results) + def test_list_custom_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1214,40 +1204,30 @@ def test_list_custom_jobs_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1256,46 +1236,35 @@ async def test_list_custom_jobs_async_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) - for i in responses) + assert all(isinstance(i, custom_job.CustomJob) for i in responses) + @pytest.mark.asyncio async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1304,37 +1273,29 @@ async def test_list_custom_jobs_async_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_custom_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): +def test_delete_custom_job( + transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1343,10 +1304,10 @@ def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.Del # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_custom_job(request) @@ -1368,25 +1329,26 @@ def test_delete_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: client.delete_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteCustomJobRequest() + @pytest.mark.asyncio -async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteCustomJobRequest): +async def test_delete_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1395,11 +1357,11 @@ async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_custom_job(request) @@ -1420,20 +1382,18 @@ async def test_delete_custom_job_async_from_dict(): def test_delete_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_custom_job(request) @@ -1444,28 +1404,25 @@ def test_delete_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_custom_job(request) @@ -1476,101 +1433,85 @@ async def test_delete_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_custom_job( - name='name_value', - ) + client.delete_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_custom_job( - job_service.DeleteCustomJobRequest(), - name='name_value', + job_service.DeleteCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_custom_job( - name='name_value', - ) + response = await client.delete_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), - name='name_value', + job_service.DeleteCustomJobRequest(), name="name_value", ) -def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): +def test_cancel_custom_job( + transport: str = "grpc", request_type=job_service.CancelCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1579,8 +1520,8 @@ def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.Can # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1604,25 +1545,26 @@ def test_cancel_custom_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: client.cancel_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelCustomJobRequest() + @pytest.mark.asyncio -async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelCustomJobRequest): +async def test_cancel_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1631,8 +1573,8 @@ async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1654,19 +1596,17 @@ async def test_cancel_custom_job_async_from_dict(): def test_cancel_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: call.return_value = None client.cancel_custom_job(request) @@ -1678,27 +1618,22 @@ def test_cancel_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_custom_job(request) @@ -1710,99 +1645,83 @@ async def test_cancel_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_custom_job( - name='name_value', - ) + client.cancel_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_custom_job( - job_service.CancelCustomJobRequest(), - name='name_value', + job_service.CancelCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_custom_job( - name='name_value', - ) + response = await client.cancel_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), - name='name_value', + job_service.CancelCustomJobRequest(), name="name_value", ) -def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): +def test_create_data_labeling_job( + transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1811,28 +1730,19 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob( - name='name_value', - - display_name='display_name_value', - - datasets=['datasets_value'], - + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], labeler_count=1375, - - instruction_uri='instruction_uri_value', - - inputs_schema_uri='inputs_schema_uri_value', - + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - - specialist_pools=['specialist_pools_value'], - + specialist_pools=["specialist_pools_value"], ) response = client.create_data_labeling_job(request) @@ -1847,23 +1757,23 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_create_data_labeling_job_from_dict(): @@ -1874,25 +1784,27 @@ def test_create_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: client.create_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateDataLabelingJobRequest): +async def test_create_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1901,20 +1813,22 @@ async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( - name='name_value', - display_name='display_name_value', - datasets=['datasets_value'], - labeler_count=1375, - instruction_uri='instruction_uri_value', - inputs_schema_uri='inputs_schema_uri_value', - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=['specialist_pools_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) response = await client.create_data_labeling_job(request) @@ -1927,23 +1841,23 @@ async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] @pytest.mark.asyncio @@ -1952,19 +1866,17 @@ async def test_create_data_labeling_job_async_from_dict(): def test_create_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: call.return_value = gca_data_labeling_job.DataLabelingJob() client.create_data_labeling_job(request) @@ -1976,28 +1888,25 @@ def test_create_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) await client.create_data_labeling_job(request) @@ -2008,29 +1917,24 @@ async def test_create_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_data_labeling_job( - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2038,45 +1942,45 @@ def test_create_data_labeling_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_data_labeling_job( - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2084,31 +1988,32 @@ async def test_create_data_labeling_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) -def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): +def test_get_data_labeling_job( + transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2117,28 +2022,19 @@ def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob( - name='name_value', - - display_name='display_name_value', - - datasets=['datasets_value'], - + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], labeler_count=1375, - - instruction_uri='instruction_uri_value', - - inputs_schema_uri='inputs_schema_uri_value', - + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - - specialist_pools=['specialist_pools_value'], - + specialist_pools=["specialist_pools_value"], ) response = client.get_data_labeling_job(request) @@ -2153,23 +2049,23 @@ def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_get_data_labeling_job_from_dict(): @@ -2180,25 +2076,26 @@ def test_get_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: client.get_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetDataLabelingJobRequest): +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2207,20 +2104,22 @@ async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( - name='name_value', - display_name='display_name_value', - datasets=['datasets_value'], - labeler_count=1375, - instruction_uri='instruction_uri_value', - inputs_schema_uri='inputs_schema_uri_value', - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=['specialist_pools_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) response = await client.get_data_labeling_job(request) @@ -2233,23 +2132,23 @@ async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] @pytest.mark.asyncio @@ -2258,19 +2157,17 @@ async def test_get_data_labeling_job_async_from_dict(): def test_get_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -2282,28 +2179,25 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) await client.get_data_labeling_job(request) @@ -2314,99 +2208,85 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job( - name='name_value', - ) + client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job( - name='name_value', - ) + response = await client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) -def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): +def test_list_data_labeling_jobs( + transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2415,12 +2295,11 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_labeling_jobs(request) @@ -2435,7 +2314,7 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_labeling_jobs_from_dict(): @@ -2446,25 +2325,27 @@ def test_list_data_labeling_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: client.list_data_labeling_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListDataLabelingJobsRequest() + @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListDataLabelingJobsRequest): +async def test_list_data_labeling_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListDataLabelingJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2473,12 +2354,14 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_labeling_jobs(request) @@ -2491,7 +2374,7 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', re # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2500,19 +2383,17 @@ async def test_list_data_labeling_jobs_async_from_dict(): def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2524,28 +2405,25 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) await client.list_data_labeling_jobs(request) @@ -2556,104 +2434,87 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_labeling_jobs( - parent='parent_value', - ) + client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs( - parent='parent_value', - ) + response = await client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2662,17 +2523,14 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2685,9 +2543,7 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_labeling_jobs(request={}) @@ -2695,18 +2551,16 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in results) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) + def test_list_data_labeling_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2715,17 +2569,14 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2736,19 +2587,20 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2757,17 +2609,14 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2778,25 +2627,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in responses) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2805,17 +2654,14 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2828,14 +2674,15 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): +def test_delete_data_labeling_job( + transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2844,10 +2691,10 @@ def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_data_labeling_job(request) @@ -2869,25 +2716,27 @@ def test_delete_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: client.delete_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteDataLabelingJobRequest): +async def test_delete_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2896,11 +2745,11 @@ async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_data_labeling_job(request) @@ -2921,20 +2770,18 @@ async def test_delete_data_labeling_job_async_from_dict(): def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_data_labeling_job(request) @@ -2945,28 +2792,25 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_data_labeling_job(request) @@ -2977,101 +2821,85 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_data_labeling_job( - name='name_value', - ) + client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job( - name='name_value', - ) + response = await client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) -def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): +def test_cancel_data_labeling_job( + transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3080,8 +2908,8 @@ def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -3105,25 +2933,27 @@ def test_cancel_data_labeling_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: client.cancel_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelDataLabelingJobRequest() + @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelDataLabelingJobRequest): +async def test_cancel_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3132,8 +2962,8 @@ async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -3155,19 +2985,17 @@ async def test_cancel_data_labeling_job_async_from_dict(): def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -3179,27 +3007,22 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -3211,99 +3034,84 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_data_labeling_job( - name='name_value', - ) + client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job( - name='name_value', - ) + response = await client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) -def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): +def test_create_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3312,22 +3120,16 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_hyperparameter_tuning_job(request) @@ -3342,9 +3144,9 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3363,25 +3165,27 @@ def test_create_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: client.create_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateHyperparameterTuningJobRequest): +async def test_create_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3390,17 +3194,19 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_hyperparameter_tuning_job(request) @@ -3413,9 +3219,9 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Establish that the response is the type that we expect. assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3432,19 +3238,17 @@ async def test_create_hyperparameter_tuning_job_async_from_dict(): def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -3456,28 +3260,25 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.create_hyperparameter_tuning_job(request) @@ -3488,29 +3289,26 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3518,45 +3316,51 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3564,31 +3368,36 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) -def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): +def test_get_hyperparameter_tuning_job( + transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3597,22 +3406,16 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_hyperparameter_tuning_job(request) @@ -3627,9 +3430,9 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3648,25 +3451,27 @@ def test_get_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: client.get_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetHyperparameterTuningJobRequest): +async def test_get_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3675,17 +3480,19 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_hyperparameter_tuning_job(request) @@ -3698,9 +3505,9 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3717,19 +3524,17 @@ async def test_get_hyperparameter_tuning_job_async_from_dict(): def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3741,28 +3546,25 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.get_hyperparameter_tuning_job(request) @@ -3773,99 +3575,86 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job( - name='name_value', - ) + client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) -def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): +def test_list_hyperparameter_tuning_jobs( + transport: str = "grpc", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3874,12 +3663,11 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3894,7 +3682,7 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3905,25 +3693,27 @@ def test_list_hyperparameter_tuning_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: client.list_hyperparameter_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListHyperparameterTuningJobsRequest): +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3932,12 +3722,14 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3950,7 +3742,7 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3959,19 +3751,17 @@ async def test_list_hyperparameter_tuning_jobs_async_from_dict(): def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3983,28 +3773,25 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) await client.list_hyperparameter_tuning_jobs(request) @@ -4015,104 +3802,87 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4121,17 +3891,16 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4144,9 +3913,7 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -4154,18 +3921,19 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results + ) + def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4174,17 +3942,16 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4195,19 +3962,20 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4216,17 +3984,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4237,25 +4004,28 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses + ) + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4264,17 +4034,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4285,16 +4054,20 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: + async for page_ in ( + await client.list_hyperparameter_tuning_jobs(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): +def test_delete_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4303,10 +4076,10 @@ def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_hyperparameter_tuning_job(request) @@ -4328,25 +4101,27 @@ def test_delete_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: client.delete_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteHyperparameterTuningJobRequest): +async def test_delete_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4355,11 +4130,11 @@ async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_hyperparameter_tuning_job(request) @@ -4380,20 +4155,18 @@ async def test_delete_hyperparameter_tuning_job_async_from_dict(): def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_hyperparameter_tuning_job(request) @@ -4404,28 +4177,25 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_hyperparameter_tuning_job(request) @@ -4436,101 +4206,86 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job( - name='name_value', - ) + client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) -def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): +def test_cancel_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4539,8 +4294,8 @@ def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -4564,25 +4319,27 @@ def test_cancel_hyperparameter_tuning_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: client.cancel_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelHyperparameterTuningJobRequest): +async def test_cancel_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4591,8 +4348,8 @@ async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -4614,19 +4371,17 @@ async def test_cancel_hyperparameter_tuning_job_async_from_dict(): def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -4638,27 +4393,22 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4670,99 +4420,83 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) -def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest): +def test_create_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4771,20 +4505,15 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_batch_prediction_job(request) @@ -4799,11 +4528,11 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4818,25 +4547,27 @@ def test_create_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: client.create_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateBatchPredictionJobRequest): +async def test_create_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4845,16 +4576,18 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_batch_prediction_job(request) @@ -4867,11 +4600,11 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4884,19 +4617,17 @@ async def test_create_batch_prediction_job_async_from_dict(): def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: call.return_value = gca_batch_prediction_job.BatchPredictionJob() client.create_batch_prediction_job(request) @@ -4908,28 +4639,25 @@ def test_create_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) await client.create_batch_prediction_job(request) @@ -4940,29 +4668,26 @@ async def test_create_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -4970,45 +4695,51 @@ def test_create_batch_prediction_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -5016,31 +4747,36 @@ async def test_create_batch_prediction_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) -def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest): +def test_get_batch_prediction_job( + transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5049,20 +4785,15 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_batch_prediction_job(request) @@ -5077,11 +4808,11 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -5096,25 +4827,27 @@ def test_get_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: client.get_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetBatchPredictionJobRequest): +async def test_get_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5123,16 +4856,18 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_batch_prediction_job(request) @@ -5145,11 +4880,11 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -5162,19 +4897,17 @@ async def test_get_batch_prediction_job_async_from_dict(): def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: call.return_value = batch_prediction_job.BatchPredictionJob() client.get_batch_prediction_job(request) @@ -5186,28 +4919,25 @@ def test_get_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) await client.get_batch_prediction_job(request) @@ -5218,99 +4948,85 @@ async def test_get_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_batch_prediction_job( - name='name_value', - ) + client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_batch_prediction_job( - name='name_value', - ) + response = await client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) -def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest): +def test_list_batch_prediction_jobs( + transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5319,12 +5035,11 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_batch_prediction_jobs(request) @@ -5339,7 +5054,7 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_batch_prediction_jobs_from_dict(): @@ -5350,25 +5065,27 @@ def test_list_batch_prediction_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: client.list_batch_prediction_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListBatchPredictionJobsRequest() + @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListBatchPredictionJobsRequest): +async def test_list_batch_prediction_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListBatchPredictionJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5377,12 +5094,14 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_batch_prediction_jobs(request) @@ -5395,7 +5114,7 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -5404,19 +5123,17 @@ async def test_list_batch_prediction_jobs_async_from_dict(): def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: call.return_value = job_service.ListBatchPredictionJobsResponse() client.list_batch_prediction_jobs(request) @@ -5428,28 +5145,25 @@ def test_list_batch_prediction_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) await client.list_batch_prediction_jobs(request) @@ -5460,104 +5174,87 @@ async def test_list_batch_prediction_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_batch_prediction_jobs( - parent='parent_value', - ) + client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs( - parent='parent_value', - ) + response = await client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5566,17 +5263,14 @@ def test_list_batch_prediction_jobs_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5589,9 +5283,7 @@ def test_list_batch_prediction_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_batch_prediction_jobs(request={}) @@ -5599,18 +5291,18 @@ def test_list_batch_prediction_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in results) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results + ) + def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5619,17 +5311,14 @@ def test_list_batch_prediction_jobs_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5640,19 +5329,20 @@ def test_list_batch_prediction_jobs_pages(): RuntimeError, ) pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5661,17 +5351,14 @@ async def test_list_batch_prediction_jobs_async_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5682,25 +5369,27 @@ async def test_list_batch_prediction_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in responses) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses + ) + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5709,17 +5398,14 @@ async def test_list_batch_prediction_jobs_async_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5732,14 +5418,15 @@ async def test_list_batch_prediction_jobs_async_pages(): pages = [] async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest): +def test_delete_batch_prediction_job( + transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5748,10 +5435,10 @@ def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_batch_prediction_job(request) @@ -5773,25 +5460,27 @@ def test_delete_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: client.delete_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteBatchPredictionJobRequest): +async def test_delete_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5800,11 +5489,11 @@ async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_batch_prediction_job(request) @@ -5825,20 +5514,18 @@ async def test_delete_batch_prediction_job_async_from_dict(): def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_batch_prediction_job(request) @@ -5849,28 +5536,25 @@ def test_delete_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_batch_prediction_job(request) @@ -5881,101 +5565,85 @@ async def test_delete_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_batch_prediction_job( - name='name_value', - ) + client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job( - name='name_value', - ) + response = await client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) -def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest): +def test_cancel_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5984,8 +5652,8 @@ def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -6009,25 +5677,27 @@ def test_cancel_batch_prediction_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: client.cancel_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CancelBatchPredictionJobRequest() + @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelBatchPredictionJobRequest): +async def test_cancel_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6036,8 +5706,8 @@ async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -6059,19 +5729,17 @@ async def test_cancel_batch_prediction_job_async_from_dict(): def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = None client.cancel_batch_prediction_job(request) @@ -6083,27 +5751,22 @@ def test_cancel_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_batch_prediction_job(request) @@ -6115,99 +5778,84 @@ async def test_cancel_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_batch_prediction_job( - name='name_value', - ) + client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job( - name='name_value', - ) + response = await client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) -def test_create_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.CreateModelDeploymentMonitoringJobRequest): +def test_create_model_deployment_monitoring_job( + transport: str = "grpc", + request_type=job_service.CreateModelDeploymentMonitoringJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6216,24 +5864,17 @@ def test_create_model_deployment_monitoring_job(transport: str = 'grpc', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( - name='name_value', - - display_name='display_name_value', - - endpoint='endpoint_value', - + name="name_value", + display_name="display_name_value", + endpoint="endpoint_value", state=job_state.JobState.JOB_STATE_QUEUED, - schedule_state=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, - - predict_instance_schema_uri='predict_instance_schema_uri_value', - - analysis_instance_schema_uri='analysis_instance_schema_uri_value', - + predict_instance_schema_uri="predict_instance_schema_uri_value", + analysis_instance_schema_uri="analysis_instance_schema_uri_value", ) response = client.create_model_deployment_monitoring_job(request) @@ -6246,21 +5887,26 @@ def test_create_model_deployment_monitoring_job(transport: str = 'grpc', request # Establish that the response is the type that we expect. - assert isinstance(response, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + assert isinstance( + response, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob + ) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.endpoint == 'endpoint_value' + assert response.endpoint == "endpoint_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.schedule_state == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + assert ( + response.schedule_state + == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + ) - assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + assert response.predict_instance_schema_uri == "predict_instance_schema_uri_value" - assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + assert response.analysis_instance_schema_uri == "analysis_instance_schema_uri_value" def test_create_model_deployment_monitoring_job_from_dict(): @@ -6271,25 +5917,27 @@ def test_create_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: client.create_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() + @pytest.mark.asyncio -async def test_create_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateModelDeploymentMonitoringJobRequest): +async def test_create_model_deployment_monitoring_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateModelDeploymentMonitoringJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6298,18 +5946,20 @@ async def test_create_model_deployment_monitoring_job_async(transport: str = 'gr # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( - name='name_value', - display_name='display_name_value', - endpoint='endpoint_value', - state=job_state.JobState.JOB_STATE_QUEUED, - schedule_state=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, - predict_instance_schema_uri='predict_instance_schema_uri_value', - analysis_instance_schema_uri='analysis_instance_schema_uri_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value", + display_name="display_name_value", + endpoint="endpoint_value", + state=job_state.JobState.JOB_STATE_QUEUED, + schedule_state=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, + predict_instance_schema_uri="predict_instance_schema_uri_value", + analysis_instance_schema_uri="analysis_instance_schema_uri_value", + ) + ) response = await client.create_model_deployment_monitoring_job(request) @@ -6320,21 +5970,26 @@ async def test_create_model_deployment_monitoring_job_async(transport: str = 'gr assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() # Establish that the response is the type that we expect. - assert isinstance(response, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + assert isinstance( + response, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob + ) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.endpoint == 'endpoint_value' + assert response.endpoint == "endpoint_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.schedule_state == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + assert ( + response.schedule_state + == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + ) - assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + assert response.predict_instance_schema_uri == "predict_instance_schema_uri_value" - assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + assert response.analysis_instance_schema_uri == "analysis_instance_schema_uri_value" @pytest.mark.asyncio @@ -6343,20 +5998,20 @@ async def test_create_model_deployment_monitoring_job_async_from_dict(): def test_create_model_deployment_monitoring_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateModelDeploymentMonitoringJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = ( + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) client.create_model_deployment_monitoring_job(request) @@ -6367,28 +6022,25 @@ def test_create_model_deployment_monitoring_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateModelDeploymentMonitoringJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) await client.create_model_deployment_monitoring_job(request) @@ -6399,29 +6051,28 @@ async def test_create_model_deployment_monitoring_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_model_deployment_monitoring_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + call.return_value = ( + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_model_deployment_monitoring_job( - parent='parent_value', - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + parent="parent_value", + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -6429,45 +6080,53 @@ def test_create_model_deployment_monitoring_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + assert args[ + 0 + ].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ) def test_create_model_deployment_monitoring_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_model_deployment_monitoring_job( job_service.CreateModelDeploymentMonitoringJobRequest(), - parent='parent_value', - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + parent="parent_value", + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.create_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + call.return_value = ( + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_model_deployment_monitoring_job( - parent='parent_value', - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + parent="parent_value", + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -6475,31 +6134,37 @@ async def test_create_model_deployment_monitoring_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + assert args[ + 0 + ].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_model_deployment_monitoring_job( job_service.CreateModelDeploymentMonitoringJobRequest(), - parent='parent_value', - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), + parent="parent_value", + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), ) -def test_search_model_deployment_monitoring_stats_anomalies(transport: str = 'grpc', request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): +def test_search_model_deployment_monitoring_stats_anomalies( + transport: str = "grpc", + request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6508,12 +6173,12 @@ def test_search_model_deployment_monitoring_stats_anomalies(transport: str = 'gr # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_model_deployment_monitoring_stats_anomalies(request) @@ -6522,13 +6187,18 @@ def test_search_model_deployment_monitoring_stats_anomalies(transport: str = 'gr assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + assert ( + args[0] + == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + ) # Establish that the response is the type that we expect. - assert isinstance(response, pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager) + assert isinstance( + response, pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager + ) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_model_deployment_monitoring_stats_anomalies_from_dict(): @@ -6539,25 +6209,31 @@ def test_search_model_deployment_monitoring_stats_anomalies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: client.search_model_deployment_monitoring_stats_anomalies() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + assert ( + args[0] + == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + ) + @pytest.mark.asyncio -async def test_search_model_deployment_monitoring_stats_anomalies_async(transport: str = 'grpc_asyncio', request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): +async def test_search_model_deployment_monitoring_stats_anomalies_async( + transport: str = "grpc_asyncio", + request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6566,47 +6242,60 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async(transpor # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( + next_page_token="next_page_token_value", + ) + ) - response = await client.search_model_deployment_monitoring_stats_anomalies(request) + response = await client.search_model_deployment_monitoring_stats_anomalies( + request + ) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + assert ( + args[0] + == job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() + ) # Establish that the response is the type that we expect. - assert isinstance(response, pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager) + assert isinstance( + response, pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager + ) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_async_from_dict(): - await test_search_model_deployment_monitoring_stats_anomalies_async(request_type=dict) + await test_search_model_deployment_monitoring_stats_anomalies_async( + request_type=dict + ) def test_search_model_deployment_monitoring_stats_anomalies_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() - request.model_deployment_monitoring_job = 'model_deployment_monitoring_job/value' + request.model_deployment_monitoring_job = "model_deployment_monitoring_job/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: - call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: + call.return_value = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + ) client.search_model_deployment_monitoring_stats_anomalies(request) @@ -6618,27 +6307,28 @@ def test_search_model_deployment_monitoring_stats_anomalies_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model_deployment_monitoring_job=model_deployment_monitoring_job/value', - ) in kw['metadata'] + "x-goog-request-params", + "model_deployment_monitoring_job=model_deployment_monitoring_job/value", + ) in kw["metadata"] @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest() - request.model_deployment_monitoring_job = 'model_deployment_monitoring_job/value' + request.model_deployment_monitoring_job = "model_deployment_monitoring_job/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse()) + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + ) await client.search_model_deployment_monitoring_stats_anomalies(request) @@ -6650,28 +6340,29 @@ async def test_search_model_deployment_monitoring_stats_anomalies_field_headers_ # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model_deployment_monitoring_job=model_deployment_monitoring_job/value', - ) in kw['metadata'] + "x-goog-request-params", + "model_deployment_monitoring_job=model_deployment_monitoring_job/value", + ) in kw["metadata"] def test_search_model_deployment_monitoring_stats_anomalies_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: # Designate an appropriate return value for the call. - call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + call.return_value = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.search_model_deployment_monitoring_stats_anomalies( - model_deployment_monitoring_job='model_deployment_monitoring_job_value', - deployed_model_id='deployed_model_id_value', + model_deployment_monitoring_job="model_deployment_monitoring_job_value", + deployed_model_id="deployed_model_id_value", ) # Establish that the underlying call was made with the expected @@ -6679,45 +6370,49 @@ def test_search_model_deployment_monitoring_stats_anomalies_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model_deployment_monitoring_job == 'model_deployment_monitoring_job_value' + assert ( + args[0].model_deployment_monitoring_job + == "model_deployment_monitoring_job_value" + ) - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" def test_search_model_deployment_monitoring_stats_anomalies_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_model_deployment_monitoring_stats_anomalies( job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(), - model_deployment_monitoring_job='model_deployment_monitoring_job_value', - deployed_model_id='deployed_model_id_value', + model_deployment_monitoring_job="model_deployment_monitoring_job_value", + deployed_model_id="deployed_model_id_value", ) @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: # Designate an appropriate return value for the call. - call.return_value = job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + call.return_value = ( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + ) - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.search_model_deployment_monitoring_stats_anomalies( - model_deployment_monitoring_job='model_deployment_monitoring_job_value', - deployed_model_id='deployed_model_id_value', + model_deployment_monitoring_job="model_deployment_monitoring_job_value", + deployed_model_id="deployed_model_id_value", ) # Establish that the underlying call was made with the expected @@ -6725,36 +6420,36 @@ async def test_search_model_deployment_monitoring_stats_anomalies_flattened_asyn assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model_deployment_monitoring_job == 'model_deployment_monitoring_job_value' + assert ( + args[0].model_deployment_monitoring_job + == "model_deployment_monitoring_job_value" + ) - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.search_model_deployment_monitoring_stats_anomalies( job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(), - model_deployment_monitoring_job='model_deployment_monitoring_job_value', - deployed_model_id='deployed_model_id_value', + model_deployment_monitoring_job="model_deployment_monitoring_job_value", + deployed_model_id="deployed_model_id_value", ) def test_search_model_deployment_monitoring_stats_anomalies_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( @@ -6763,17 +6458,16 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager(): gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( - monitoring_stats=[], - next_page_token='def', + monitoring_stats=[], next_page_token="def", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ @@ -6786,9 +6480,9 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job', ''), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model_deployment_monitoring_job", ""),) + ), ) pager = client.search_model_deployment_monitoring_stats_anomalies(request={}) @@ -6796,18 +6490,22 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies) - for i in results) + assert all( + isinstance( + i, gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies + ) + for i in results + ) + def test_search_model_deployment_monitoring_stats_anomalies_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__') as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( @@ -6816,17 +6514,16 @@ def test_search_model_deployment_monitoring_stats_anomalies_pages(): gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( - monitoring_stats=[], - next_page_token='def', + monitoring_stats=[], next_page_token="def", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ @@ -6836,20 +6533,23 @@ def test_search_model_deployment_monitoring_stats_anomalies_pages(): ), RuntimeError, ) - pages = list(client.search_model_deployment_monitoring_stats_anomalies(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + pages = list( + client.search_model_deployment_monitoring_stats_anomalies(request={}).pages + ) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( @@ -6858,17 +6558,16 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async_pager(): gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( - monitoring_stats=[], - next_page_token='def', + monitoring_stats=[], next_page_token="def", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ @@ -6878,26 +6577,33 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async_pager(): ), RuntimeError, ) - async_pager = await client.search_model_deployment_monitoring_stats_anomalies(request={},) - assert async_pager.next_page_token == 'abc' + async_pager = await client.search_model_deployment_monitoring_stats_anomalies( + request={}, + ) + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies) - for i in responses) + assert all( + isinstance( + i, gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies + ) + for i in responses + ) + @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_model_deployment_monitoring_stats_anomalies), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_model_deployment_monitoring_stats_anomalies), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( @@ -6906,17 +6612,16 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async_pages(): gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( - monitoring_stats=[], - next_page_token='def', + monitoring_stats=[], next_page_token="def", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ gca_model_deployment_monitoring_job.ModelMonitoringStatsAnomalies(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse( monitoring_stats=[ @@ -6927,16 +6632,20 @@ async def test_search_model_deployment_monitoring_stats_anomalies_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.search_model_deployment_monitoring_stats_anomalies(request={})).pages: + async for page_ in ( + await client.search_model_deployment_monitoring_stats_anomalies(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.GetModelDeploymentMonitoringJobRequest): +def test_get_model_deployment_monitoring_job( + transport: str = "grpc", + request_type=job_service.GetModelDeploymentMonitoringJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6945,24 +6654,17 @@ def test_get_model_deployment_monitoring_job(transport: str = 'grpc', request_ty # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob( - name='name_value', - - display_name='display_name_value', - - endpoint='endpoint_value', - + name="name_value", + display_name="display_name_value", + endpoint="endpoint_value", state=job_state.JobState.JOB_STATE_QUEUED, - schedule_state=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, - - predict_instance_schema_uri='predict_instance_schema_uri_value', - - analysis_instance_schema_uri='analysis_instance_schema_uri_value', - + predict_instance_schema_uri="predict_instance_schema_uri_value", + analysis_instance_schema_uri="analysis_instance_schema_uri_value", ) response = client.get_model_deployment_monitoring_job(request) @@ -6975,21 +6677,26 @@ def test_get_model_deployment_monitoring_job(transport: str = 'grpc', request_ty # Establish that the response is the type that we expect. - assert isinstance(response, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + assert isinstance( + response, model_deployment_monitoring_job.ModelDeploymentMonitoringJob + ) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.endpoint == 'endpoint_value' + assert response.endpoint == "endpoint_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.schedule_state == model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + assert ( + response.schedule_state + == model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + ) - assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + assert response.predict_instance_schema_uri == "predict_instance_schema_uri_value" - assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + assert response.analysis_instance_schema_uri == "analysis_instance_schema_uri_value" def test_get_model_deployment_monitoring_job_from_dict(): @@ -7000,25 +6707,27 @@ def test_get_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: client.get_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() + @pytest.mark.asyncio -async def test_get_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetModelDeploymentMonitoringJobRequest): +async def test_get_model_deployment_monitoring_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetModelDeploymentMonitoringJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7027,18 +6736,20 @@ async def test_get_model_deployment_monitoring_job_async(transport: str = 'grpc_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_deployment_monitoring_job.ModelDeploymentMonitoringJob( - name='name_value', - display_name='display_name_value', - endpoint='endpoint_value', - state=job_state.JobState.JOB_STATE_QUEUED, - schedule_state=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, - predict_instance_schema_uri='predict_instance_schema_uri_value', - analysis_instance_schema_uri='analysis_instance_schema_uri_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value", + display_name="display_name_value", + endpoint="endpoint_value", + state=job_state.JobState.JOB_STATE_QUEUED, + schedule_state=model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING, + predict_instance_schema_uri="predict_instance_schema_uri_value", + analysis_instance_schema_uri="analysis_instance_schema_uri_value", + ) + ) response = await client.get_model_deployment_monitoring_job(request) @@ -7049,21 +6760,26 @@ async def test_get_model_deployment_monitoring_job_async(transport: str = 'grpc_ assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() # Establish that the response is the type that we expect. - assert isinstance(response, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + assert isinstance( + response, model_deployment_monitoring_job.ModelDeploymentMonitoringJob + ) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.endpoint == 'endpoint_value' + assert response.endpoint == "endpoint_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.schedule_state == model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + assert ( + response.schedule_state + == model_deployment_monitoring_job.ModelDeploymentMonitoringJob.MonitoringScheduleState.PENDING + ) - assert response.predict_instance_schema_uri == 'predict_instance_schema_uri_value' + assert response.predict_instance_schema_uri == "predict_instance_schema_uri_value" - assert response.analysis_instance_schema_uri == 'analysis_instance_schema_uri_value' + assert response.analysis_instance_schema_uri == "analysis_instance_schema_uri_value" @pytest.mark.asyncio @@ -7072,20 +6788,20 @@ async def test_get_model_deployment_monitoring_job_async_from_dict(): def test_get_model_deployment_monitoring_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = ( + model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) client.get_model_deployment_monitoring_job(request) @@ -7096,28 +6812,25 @@ def test_get_model_deployment_monitoring_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) await client.get_model_deployment_monitoring_job(request) @@ -7128,99 +6841,90 @@ async def test_get_model_deployment_monitoring_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_deployment_monitoring_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + call.return_value = ( + model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_deployment_monitoring_job( - name='name_value', - ) + client.get_model_deployment_monitoring_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_deployment_monitoring_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_deployment_monitoring_job( - job_service.GetModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.GetModelDeploymentMonitoringJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.get_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + call.return_value = ( + model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_deployment_monitoring_job.ModelDeploymentMonitoringJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_deployment_monitoring_job.ModelDeploymentMonitoringJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_deployment_monitoring_job( - name='name_value', - ) + response = await client.get_model_deployment_monitoring_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_deployment_monitoring_job( - job_service.GetModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.GetModelDeploymentMonitoringJobRequest(), name="name_value", ) -def test_list_model_deployment_monitoring_jobs(transport: str = 'grpc', request_type=job_service.ListModelDeploymentMonitoringJobsRequest): +def test_list_model_deployment_monitoring_jobs( + transport: str = "grpc", + request_type=job_service.ListModelDeploymentMonitoringJobsRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7229,12 +6933,11 @@ def test_list_model_deployment_monitoring_jobs(transport: str = 'grpc', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_deployment_monitoring_jobs(request) @@ -7249,7 +6952,7 @@ def test_list_model_deployment_monitoring_jobs(transport: str = 'grpc', request_ assert isinstance(response, pagers.ListModelDeploymentMonitoringJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_deployment_monitoring_jobs_from_dict(): @@ -7260,25 +6963,27 @@ def test_list_model_deployment_monitoring_jobs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: client.list_model_deployment_monitoring_jobs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ListModelDeploymentMonitoringJobsRequest() + @pytest.mark.asyncio -async def test_list_model_deployment_monitoring_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListModelDeploymentMonitoringJobsRequest): +async def test_list_model_deployment_monitoring_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListModelDeploymentMonitoringJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7287,12 +6992,14 @@ async def test_list_model_deployment_monitoring_jobs_async(transport: str = 'grp # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListModelDeploymentMonitoringJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListModelDeploymentMonitoringJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_deployment_monitoring_jobs(request) @@ -7305,7 +7012,7 @@ async def test_list_model_deployment_monitoring_jobs_async(transport: str = 'grp # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelDeploymentMonitoringJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -7314,19 +7021,17 @@ async def test_list_model_deployment_monitoring_jobs_async_from_dict(): def test_list_model_deployment_monitoring_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListModelDeploymentMonitoringJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse() client.list_model_deployment_monitoring_jobs(request) @@ -7338,28 +7043,25 @@ def test_list_model_deployment_monitoring_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListModelDeploymentMonitoringJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListModelDeploymentMonitoringJobsResponse()) + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListModelDeploymentMonitoringJobsResponse() + ) await client.list_model_deployment_monitoring_jobs(request) @@ -7370,70 +7072,61 @@ async def test_list_model_deployment_monitoring_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_deployment_monitoring_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_deployment_monitoring_jobs( - parent='parent_value', - ) + client.list_model_deployment_monitoring_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_deployment_monitoring_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_deployment_monitoring_jobs( job_service.ListModelDeploymentMonitoringJobsRequest(), - parent='parent_value', + parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListModelDeploymentMonitoringJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListModelDeploymentMonitoringJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListModelDeploymentMonitoringJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.list_model_deployment_monitoring_jobs( - parent='parent_value', + parent="parent_value", ) # Establish that the underlying call was made with the expected @@ -7441,33 +7134,29 @@ async def test_list_model_deployment_monitoring_jobs_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_deployment_monitoring_jobs( job_service.ListModelDeploymentMonitoringJobsRequest(), - parent='parent_value', + parent="parent_value", ) def test_list_model_deployment_monitoring_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListModelDeploymentMonitoringJobsResponse( @@ -7476,17 +7165,16 @@ def test_list_model_deployment_monitoring_jobs_pager(): model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListModelDeploymentMonitoringJobsResponse( - model_deployment_monitoring_jobs=[], - next_page_token='def', + model_deployment_monitoring_jobs=[], next_page_token="def", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ @@ -7499,9 +7187,7 @@ def test_list_model_deployment_monitoring_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_deployment_monitoring_jobs(request={}) @@ -7509,18 +7195,19 @@ def test_list_model_deployment_monitoring_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) - for i in results) + assert all( + isinstance(i, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + for i in results + ) + def test_list_model_deployment_monitoring_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__') as call: + type(client.transport.list_model_deployment_monitoring_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListModelDeploymentMonitoringJobsResponse( @@ -7529,17 +7216,16 @@ def test_list_model_deployment_monitoring_jobs_pages(): model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListModelDeploymentMonitoringJobsResponse( - model_deployment_monitoring_jobs=[], - next_page_token='def', + model_deployment_monitoring_jobs=[], next_page_token="def", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ @@ -7550,19 +7236,20 @@ def test_list_model_deployment_monitoring_jobs_pages(): RuntimeError, ) pages = list(client.list_model_deployment_monitoring_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_deployment_monitoring_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListModelDeploymentMonitoringJobsResponse( @@ -7571,17 +7258,16 @@ async def test_list_model_deployment_monitoring_jobs_async_pager(): model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListModelDeploymentMonitoringJobsResponse( - model_deployment_monitoring_jobs=[], - next_page_token='def', + model_deployment_monitoring_jobs=[], next_page_token="def", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ @@ -7592,25 +7278,28 @@ async def test_list_model_deployment_monitoring_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_model_deployment_monitoring_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) - for i in responses) + assert all( + isinstance(i, model_deployment_monitoring_job.ModelDeploymentMonitoringJob) + for i in responses + ) + @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_deployment_monitoring_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_deployment_monitoring_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListModelDeploymentMonitoringJobsResponse( @@ -7619,17 +7308,16 @@ async def test_list_model_deployment_monitoring_jobs_async_pages(): model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListModelDeploymentMonitoringJobsResponse( - model_deployment_monitoring_jobs=[], - next_page_token='def', + model_deployment_monitoring_jobs=[], next_page_token="def", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ model_deployment_monitoring_job.ModelDeploymentMonitoringJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListModelDeploymentMonitoringJobsResponse( model_deployment_monitoring_jobs=[ @@ -7640,16 +7328,20 @@ async def test_list_model_deployment_monitoring_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_model_deployment_monitoring_jobs(request={})).pages: + async for page_ in ( + await client.list_model_deployment_monitoring_jobs(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.UpdateModelDeploymentMonitoringJobRequest): +def test_update_model_deployment_monitoring_job( + transport: str = "grpc", + request_type=job_service.UpdateModelDeploymentMonitoringJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7658,10 +7350,10 @@ def test_update_model_deployment_monitoring_job(transport: str = 'grpc', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_model_deployment_monitoring_job(request) @@ -7683,25 +7375,27 @@ def test_update_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: client.update_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() + @pytest.mark.asyncio -async def test_update_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.UpdateModelDeploymentMonitoringJobRequest): +async def test_update_model_deployment_monitoring_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.UpdateModelDeploymentMonitoringJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7710,11 +7404,11 @@ async def test_update_model_deployment_monitoring_job_async(transport: str = 'gr # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_model_deployment_monitoring_job(request) @@ -7735,20 +7429,20 @@ async def test_update_model_deployment_monitoring_job_async_from_dict(): def test_update_model_deployment_monitoring_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.UpdateModelDeploymentMonitoringJobRequest() - request.model_deployment_monitoring_job.name = 'model_deployment_monitoring_job.name/value' + request.model_deployment_monitoring_job.name = ( + "model_deployment_monitoring_job.name/value" + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_model_deployment_monitoring_job(request) @@ -7760,27 +7454,29 @@ def test_update_model_deployment_monitoring_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model_deployment_monitoring_job.name=model_deployment_monitoring_job.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "model_deployment_monitoring_job.name=model_deployment_monitoring_job.name/value", + ) in kw["metadata"] @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.UpdateModelDeploymentMonitoringJobRequest() - request.model_deployment_monitoring_job.name = 'model_deployment_monitoring_job.name/value' + request.model_deployment_monitoring_job.name = ( + "model_deployment_monitoring_job.name/value" + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_model_deployment_monitoring_job(request) @@ -7792,28 +7488,28 @@ async def test_update_model_deployment_monitoring_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model_deployment_monitoring_job.name=model_deployment_monitoring_job.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "model_deployment_monitoring_job.name=model_deployment_monitoring_job.name/value", + ) in kw["metadata"] def test_update_model_deployment_monitoring_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_model_deployment_monitoring_job( - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -7821,47 +7517,51 @@ def test_update_model_deployment_monitoring_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + assert args[ + 0 + ].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_model_deployment_monitoring_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_model_deployment_monitoring_job( job_service.UpdateModelDeploymentMonitoringJobRequest(), - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.update_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model_deployment_monitoring_job( - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -7869,31 +7569,37 @@ async def test_update_model_deployment_monitoring_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value') + assert args[ + 0 + ].model_deployment_monitoring_job == gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_model_deployment_monitoring_job( job_service.UpdateModelDeploymentMonitoringJobRequest(), - model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model_deployment_monitoring_job=gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.DeleteModelDeploymentMonitoringJobRequest): +def test_delete_model_deployment_monitoring_job( + transport: str = "grpc", + request_type=job_service.DeleteModelDeploymentMonitoringJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7902,10 +7608,10 @@ def test_delete_model_deployment_monitoring_job(transport: str = 'grpc', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_model_deployment_monitoring_job(request) @@ -7927,25 +7633,27 @@ def test_delete_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: client.delete_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.DeleteModelDeploymentMonitoringJobRequest() + @pytest.mark.asyncio -async def test_delete_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteModelDeploymentMonitoringJobRequest): +async def test_delete_model_deployment_monitoring_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteModelDeploymentMonitoringJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7954,11 +7662,11 @@ async def test_delete_model_deployment_monitoring_job_async(transport: str = 'gr # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_model_deployment_monitoring_job(request) @@ -7979,20 +7687,18 @@ async def test_delete_model_deployment_monitoring_job_async_from_dict(): def test_delete_model_deployment_monitoring_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_model_deployment_monitoring_job(request) @@ -8003,28 +7709,25 @@ def test_delete_model_deployment_monitoring_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_model_deployment_monitoring_job(request) @@ -8035,72 +7738,60 @@ async def test_delete_model_deployment_monitoring_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_model_deployment_monitoring_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model_deployment_monitoring_job( - name='name_value', - ) + client.delete_model_deployment_monitoring_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_model_deployment_monitoring_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_model_deployment_monitoring_job( - job_service.DeleteModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.DeleteModelDeploymentMonitoringJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.delete_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.delete_model_deployment_monitoring_job( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -8108,28 +7799,27 @@ async def test_delete_model_deployment_monitoring_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model_deployment_monitoring_job( - job_service.DeleteModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.DeleteModelDeploymentMonitoringJobRequest(), name="name_value", ) -def test_pause_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.PauseModelDeploymentMonitoringJobRequest): +def test_pause_model_deployment_monitoring_job( + transport: str = "grpc", + request_type=job_service.PauseModelDeploymentMonitoringJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -8138,8 +7828,8 @@ def test_pause_model_deployment_monitoring_job(transport: str = 'grpc', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -8163,25 +7853,27 @@ def test_pause_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: client.pause_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.PauseModelDeploymentMonitoringJobRequest() + @pytest.mark.asyncio -async def test_pause_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.PauseModelDeploymentMonitoringJobRequest): +async def test_pause_model_deployment_monitoring_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.PauseModelDeploymentMonitoringJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -8190,8 +7882,8 @@ async def test_pause_model_deployment_monitoring_job_async(transport: str = 'grp # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -8213,19 +7905,17 @@ async def test_pause_model_deployment_monitoring_job_async_from_dict(): def test_pause_model_deployment_monitoring_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.PauseModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: call.return_value = None client.pause_model_deployment_monitoring_job(request) @@ -8237,27 +7927,22 @@ def test_pause_model_deployment_monitoring_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.PauseModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.pause_model_deployment_monitoring_job(request) @@ -8269,62 +7954,50 @@ async def test_pause_model_deployment_monitoring_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_pause_model_deployment_monitoring_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.pause_model_deployment_monitoring_job( - name='name_value', - ) + client.pause_model_deployment_monitoring_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_pause_model_deployment_monitoring_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.pause_model_deployment_monitoring_job( - job_service.PauseModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.PauseModelDeploymentMonitoringJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.pause_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.pause_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -8332,7 +8005,7 @@ async def test_pause_model_deployment_monitoring_job_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.pause_model_deployment_monitoring_job( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -8340,28 +8013,27 @@ async def test_pause_model_deployment_monitoring_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.pause_model_deployment_monitoring_job( - job_service.PauseModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.PauseModelDeploymentMonitoringJobRequest(), name="name_value", ) -def test_resume_model_deployment_monitoring_job(transport: str = 'grpc', request_type=job_service.ResumeModelDeploymentMonitoringJobRequest): +def test_resume_model_deployment_monitoring_job( + transport: str = "grpc", + request_type=job_service.ResumeModelDeploymentMonitoringJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -8370,8 +8042,8 @@ def test_resume_model_deployment_monitoring_job(transport: str = 'grpc', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -8395,25 +8067,27 @@ def test_resume_model_deployment_monitoring_job_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: client.resume_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.ResumeModelDeploymentMonitoringJobRequest() + @pytest.mark.asyncio -async def test_resume_model_deployment_monitoring_job_async(transport: str = 'grpc_asyncio', request_type=job_service.ResumeModelDeploymentMonitoringJobRequest): +async def test_resume_model_deployment_monitoring_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.ResumeModelDeploymentMonitoringJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -8422,8 +8096,8 @@ async def test_resume_model_deployment_monitoring_job_async(transport: str = 'gr # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -8445,19 +8119,17 @@ async def test_resume_model_deployment_monitoring_job_async_from_dict(): def test_resume_model_deployment_monitoring_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ResumeModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: call.return_value = None client.resume_model_deployment_monitoring_job(request) @@ -8469,27 +8141,22 @@ def test_resume_model_deployment_monitoring_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ResumeModelDeploymentMonitoringJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.resume_model_deployment_monitoring_job(request) @@ -8501,62 +8168,50 @@ async def test_resume_model_deployment_monitoring_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_resume_model_deployment_monitoring_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.resume_model_deployment_monitoring_job( - name='name_value', - ) + client.resume_model_deployment_monitoring_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_resume_model_deployment_monitoring_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.resume_model_deployment_monitoring_job( - job_service.ResumeModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.ResumeModelDeploymentMonitoringJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.resume_model_deployment_monitoring_job), - '__call__') as call: + type(client.transport.resume_model_deployment_monitoring_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -8564,7 +8219,7 @@ async def test_resume_model_deployment_monitoring_job_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.resume_model_deployment_monitoring_job( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -8572,21 +8227,18 @@ async def test_resume_model_deployment_monitoring_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.resume_model_deployment_monitoring_job( - job_service.ResumeModelDeploymentMonitoringJobRequest(), - name='name_value', + job_service.ResumeModelDeploymentMonitoringJobRequest(), name="name_value", ) @@ -8597,8 +8249,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -8617,8 +8268,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -8646,13 +8296,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.JobServiceGrpcTransport, - transports.JobServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport,], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -8660,13 +8310,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.JobServiceGrpcTransport, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.JobServiceGrpcTransport,) def test_job_service_base_transport_error(): @@ -8674,13 +8319,15 @@ def test_job_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_job_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -8689,35 +8336,35 @@ def test_job_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_custom_job', - 'get_custom_job', - 'list_custom_jobs', - 'delete_custom_job', - 'cancel_custom_job', - 'create_data_labeling_job', - 'get_data_labeling_job', - 'list_data_labeling_jobs', - 'delete_data_labeling_job', - 'cancel_data_labeling_job', - 'create_hyperparameter_tuning_job', - 'get_hyperparameter_tuning_job', - 'list_hyperparameter_tuning_jobs', - 'delete_hyperparameter_tuning_job', - 'cancel_hyperparameter_tuning_job', - 'create_batch_prediction_job', - 'get_batch_prediction_job', - 'list_batch_prediction_jobs', - 'delete_batch_prediction_job', - 'cancel_batch_prediction_job', - 'create_model_deployment_monitoring_job', - 'search_model_deployment_monitoring_stats_anomalies', - 'get_model_deployment_monitoring_job', - 'list_model_deployment_monitoring_jobs', - 'update_model_deployment_monitoring_job', - 'delete_model_deployment_monitoring_job', - 'pause_model_deployment_monitoring_job', - 'resume_model_deployment_monitoring_job', - ) + "create_custom_job", + "get_custom_job", + "list_custom_jobs", + "delete_custom_job", + "cancel_custom_job", + "create_data_labeling_job", + "get_data_labeling_job", + "list_data_labeling_jobs", + "delete_data_labeling_job", + "cancel_data_labeling_job", + "create_hyperparameter_tuning_job", + "get_hyperparameter_tuning_job", + "list_hyperparameter_tuning_jobs", + "delete_hyperparameter_tuning_job", + "cancel_hyperparameter_tuning_job", + "create_batch_prediction_job", + "get_batch_prediction_job", + "list_batch_prediction_jobs", + "delete_batch_prediction_job", + "cancel_batch_prediction_job", + "create_model_deployment_monitoring_job", + "search_model_deployment_monitoring_stats_anomalies", + "get_model_deployment_monitoring_job", + "list_model_deployment_monitoring_jobs", + "update_model_deployment_monitoring_job", + "delete_model_deployment_monitoring_job", + "pause_model_deployment_monitoring_job", + "resume_model_deployment_monitoring_job", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -8730,23 +8377,28 @@ def test_job_service_base_transport(): def test_job_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_job_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport() @@ -8755,11 +8407,11 @@ def test_job_service_base_transport_with_adc(): def test_job_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) JobServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -8767,19 +8419,22 @@ def test_job_service_auth_adc(): def test_job_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.JobServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -8788,15 +8443,13 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -8811,38 +8464,40 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_host_with_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_job_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -8850,12 +8505,11 @@ def test_job_service_grpc_transport_channel(): def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -8864,12 +8518,17 @@ def test_job_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -8878,7 +8537,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -8894,9 +8553,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8910,17 +8567,20 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -8937,9 +8597,7 @@ def test_job_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8952,16 +8610,12 @@ def test_job_service_transport_channel_mtls_with_adc( def test_job_service_grpc_lro_client(): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -8969,16 +8623,12 @@ def test_job_service_grpc_lro_client(): def test_job_service_grpc_lro_async_client(): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -8989,17 +8639,20 @@ def test_batch_prediction_job_path(): location = "clam" batch_prediction_job = "whelk" - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) - actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, location=location, batch_prediction_job=batch_prediction_job, + ) + actual = JobServiceClient.batch_prediction_job_path( + project, location, batch_prediction_job + ) assert expected == actual def test_parse_batch_prediction_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", } path = JobServiceClient.batch_prediction_job_path(**expected) @@ -9007,22 +8660,24 @@ def test_parse_batch_prediction_job_path(): actual = JobServiceClient.parse_batch_prediction_job_path(path) assert expected == actual + def test_custom_job_path(): project = "cuttlefish" location = "mussel" custom_job = "winkle" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) actual = JobServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", - + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", } path = JobServiceClient.custom_job_path(**expected) @@ -9030,22 +8685,26 @@ def test_parse_custom_job_path(): actual = JobServiceClient.parse_custom_job_path(path) assert expected == actual + def test_data_labeling_job_path(): project = "squid" location = "clam" data_labeling_job = "whelk" - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) - actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) + actual = JobServiceClient.data_labeling_job_path( + project, location, data_labeling_job + ) assert expected == actual def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", } path = JobServiceClient.data_labeling_job_path(**expected) @@ -9053,22 +8712,24 @@ def test_parse_data_labeling_job_path(): actual = JobServiceClient.parse_data_labeling_job_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = JobServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = JobServiceClient.dataset_path(**expected) @@ -9076,22 +8737,24 @@ def test_parse_dataset_path(): actual = JobServiceClient.parse_dataset_path(path) assert expected == actual + def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = JobServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = JobServiceClient.endpoint_path(**expected) @@ -9099,22 +8762,28 @@ def test_parse_endpoint_path(): actual = JobServiceClient.parse_endpoint_path(path) assert expected == actual + def test_hyperparameter_tuning_job_path(): project = "cuttlefish" location = "mussel" hyperparameter_tuning_job = "winkle" - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) - actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + actual = JobServiceClient.hyperparameter_tuning_job_path( + project, location, hyperparameter_tuning_job + ) assert expected == actual def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "hyperparameter_tuning_job": "abalone", - + "project": "nautilus", + "location": "scallop", + "hyperparameter_tuning_job": "abalone", } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -9122,22 +8791,24 @@ def test_parse_hyperparameter_tuning_job_path(): actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) assert expected == actual + def test_model_path(): project = "squid" location = "clam" model = "whelk" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = JobServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "octopus", - "location": "oyster", - "model": "nudibranch", - + "project": "octopus", + "location": "oyster", + "model": "nudibranch", } path = JobServiceClient.model_path(**expected) @@ -9145,22 +8816,28 @@ def test_parse_model_path(): actual = JobServiceClient.parse_model_path(path) assert expected == actual + def test_model_deployment_monitoring_job_path(): project = "cuttlefish" location = "mussel" model_deployment_monitoring_job = "winkle" - expected = "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format(project=project, location=location, model_deployment_monitoring_job=model_deployment_monitoring_job, ) - actual = JobServiceClient.model_deployment_monitoring_job_path(project, location, model_deployment_monitoring_job) + expected = "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format( + project=project, + location=location, + model_deployment_monitoring_job=model_deployment_monitoring_job, + ) + actual = JobServiceClient.model_deployment_monitoring_job_path( + project, location, model_deployment_monitoring_job + ) assert expected == actual def test_parse_model_deployment_monitoring_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model_deployment_monitoring_job": "abalone", - + "project": "nautilus", + "location": "scallop", + "model_deployment_monitoring_job": "abalone", } path = JobServiceClient.model_deployment_monitoring_job_path(**expected) @@ -9168,24 +8845,26 @@ def test_parse_model_deployment_monitoring_job_path(): actual = JobServiceClient.parse_model_deployment_monitoring_job_path(path) assert expected == actual + def test_trial_path(): project = "squid" location = "clam" study = "whelk" trial = "octopus" - expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) actual = JobServiceClient.trial_path(project, location, study, trial) assert expected == actual def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", - + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", } path = JobServiceClient.trial_path(**expected) @@ -9193,18 +8872,20 @@ def test_parse_trial_path(): actual = JobServiceClient.parse_trial_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = JobServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = JobServiceClient.common_billing_account_path(**expected) @@ -9212,18 +8893,18 @@ def test_parse_common_billing_account_path(): actual = JobServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = JobServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = JobServiceClient.common_folder_path(**expected) @@ -9231,18 +8912,18 @@ def test_parse_common_folder_path(): actual = JobServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = JobServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = JobServiceClient.common_organization_path(**expected) @@ -9250,18 +8931,18 @@ def test_parse_common_organization_path(): actual = JobServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = JobServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = JobServiceClient.common_project_path(**expected) @@ -9269,20 +8950,22 @@ def test_parse_common_project_path(): actual = JobServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = JobServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = JobServiceClient.common_location_path(**expected) @@ -9294,17 +8977,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = JobServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index 0a71403d33..7ae26844a0 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.metadata_service import MetadataServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.metadata_service import MetadataServiceClient +from google.cloud.aiplatform_v1beta1.services.metadata_service import ( + MetadataServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.metadata_service import ( + MetadataServiceClient, +) from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers from google.cloud.aiplatform_v1beta1.services.metadata_service import transports from google.cloud.aiplatform_v1beta1.types import artifact @@ -69,7 +73,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -80,36 +88,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MetadataServiceClient._get_default_mtls_endpoint(None) is None - assert MetadataServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MetadataServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - MetadataServiceClient, - MetadataServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MetadataServiceClient, MetadataServiceAsyncClient,] +) def test_metadata_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - MetadataServiceClient, - MetadataServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MetadataServiceClient, MetadataServiceAsyncClient,] +) def test_metadata_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -119,7 +143,7 @@ def test_metadata_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_metadata_service_client_get_transport_class(): @@ -133,29 +157,44 @@ def test_metadata_service_client_get_transport_class(): assert transport == transports.MetadataServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(MetadataServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceClient)) -@mock.patch.object(MetadataServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceAsyncClient)) -def test_metadata_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MetadataServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceClient), +) +@mock.patch.object( + MetadataServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceAsyncClient), +) +def test_metadata_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MetadataServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MetadataServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MetadataServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MetadataServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -171,7 +210,7 @@ def test_metadata_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -187,7 +226,7 @@ def test_metadata_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -207,13 +246,15 @@ def test_metadata_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -226,26 +267,62 @@ def test_metadata_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc", "true"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc", "false"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(MetadataServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceClient)) -@mock.patch.object(MetadataServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MetadataServiceClient, + transports.MetadataServiceGrpcTransport, + "grpc", + "true", + ), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MetadataServiceClient, + transports.MetadataServiceGrpcTransport, + "grpc", + "false", + ), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MetadataServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceClient), +) +@mock.patch.object( + MetadataServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_metadata_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_metadata_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -268,10 +345,18 @@ def test_metadata_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -292,9 +377,14 @@ def test_metadata_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -308,16 +398,23 @@ def test_metadata_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_metadata_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_metadata_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -330,16 +427,24 @@ def test_metadata_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_metadata_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_metadata_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -354,10 +459,12 @@ def test_metadata_service_client_client_options_credentials_file(client_class, t def test_metadata_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MetadataServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -370,10 +477,11 @@ def test_metadata_service_client_client_options_from_dict(): ) -def test_create_metadata_store(transport: str = 'grpc', request_type=metadata_service.CreateMetadataStoreRequest): +def test_create_metadata_store( + transport: str = "grpc", request_type=metadata_service.CreateMetadataStoreRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -382,10 +490,10 @@ def test_create_metadata_store(transport: str = 'grpc', request_type=metadata_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_metadata_store(request) @@ -407,25 +515,27 @@ def test_create_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: client.create_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateMetadataStoreRequest() + @pytest.mark.asyncio -async def test_create_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateMetadataStoreRequest): +async def test_create_metadata_store_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.CreateMetadataStoreRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -434,11 +544,11 @@ async def test_create_metadata_store_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_metadata_store(request) @@ -459,20 +569,18 @@ async def test_create_metadata_store_async_from_dict(): def test_create_metadata_store_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataStoreRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_metadata_store), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_metadata_store(request) @@ -483,28 +591,25 @@ def test_create_metadata_store_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_metadata_store_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataStoreRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_metadata_store), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_metadata_store(request) @@ -515,30 +620,25 @@ async def test_create_metadata_store_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_metadata_store_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_metadata_store( - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) # Establish that the underlying call was made with the expected @@ -546,51 +646,49 @@ def test_create_metadata_store_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_store == gca_metadata_store.MetadataStore(name='name_value') + assert args[0].metadata_store == gca_metadata_store.MetadataStore( + name="name_value" + ) - assert args[0].metadata_store_id == 'metadata_store_id_value' + assert args[0].metadata_store_id == "metadata_store_id_value" def test_create_metadata_store_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_metadata_store( metadata_service.CreateMetadataStoreRequest(), - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) @pytest.mark.asyncio async def test_create_metadata_store_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_metadata_store( - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) # Establish that the underlying call was made with the expected @@ -598,34 +696,35 @@ async def test_create_metadata_store_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_store == gca_metadata_store.MetadataStore(name='name_value') + assert args[0].metadata_store == gca_metadata_store.MetadataStore( + name="name_value" + ) - assert args[0].metadata_store_id == 'metadata_store_id_value' + assert args[0].metadata_store_id == "metadata_store_id_value" @pytest.mark.asyncio async def test_create_metadata_store_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_metadata_store( metadata_service.CreateMetadataStoreRequest(), - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) -def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_service.GetMetadataStoreRequest): +def test_get_metadata_store( + transport: str = "grpc", request_type=metadata_service.GetMetadataStoreRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -634,13 +733,10 @@ def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_store.MetadataStore( - name='name_value', - - ) + call.return_value = metadata_store.MetadataStore(name="name_value",) response = client.get_metadata_store(request) @@ -654,7 +750,7 @@ def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_servi assert isinstance(response, metadata_store.MetadataStore) - assert response.name == 'name_value' + assert response.name == "name_value" def test_get_metadata_store_from_dict(): @@ -665,25 +761,27 @@ def test_get_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: client.get_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetMetadataStoreRequest() + @pytest.mark.asyncio -async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetMetadataStoreRequest): +async def test_get_metadata_store_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.GetMetadataStoreRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -692,12 +790,12 @@ async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore( - name='name_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_store.MetadataStore(name="name_value",) + ) response = await client.get_metadata_store(request) @@ -710,7 +808,7 @@ async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request # Establish that the response is the type that we expect. assert isinstance(response, metadata_store.MetadataStore) - assert response.name == 'name_value' + assert response.name == "name_value" @pytest.mark.asyncio @@ -719,19 +817,17 @@ async def test_get_metadata_store_async_from_dict(): def test_get_metadata_store_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: call.return_value = metadata_store.MetadataStore() client.get_metadata_store(request) @@ -743,28 +839,25 @@ def test_get_metadata_store_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_metadata_store_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore()) + type(client.transport.get_metadata_store), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_store.MetadataStore() + ) await client.get_metadata_store(request) @@ -775,99 +868,85 @@ async def test_get_metadata_store_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_metadata_store_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_store.MetadataStore() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_metadata_store( - name='name_value', - ) + client.get_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_metadata_store_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_metadata_store( - metadata_service.GetMetadataStoreRequest(), - name='name_value', + metadata_service.GetMetadataStoreRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_metadata_store_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_store.MetadataStore() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_store.MetadataStore() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_metadata_store( - name='name_value', - ) + response = await client.get_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_metadata_store_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_metadata_store( - metadata_service.GetMetadataStoreRequest(), - name='name_value', + metadata_service.GetMetadataStoreRequest(), name="name_value", ) -def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_service.ListMetadataStoresRequest): +def test_list_metadata_stores( + transport: str = "grpc", request_type=metadata_service.ListMetadataStoresRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -876,12 +955,11 @@ def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_ser # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataStoresResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_metadata_stores(request) @@ -896,7 +974,7 @@ def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_ser assert isinstance(response, pagers.ListMetadataStoresPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_metadata_stores_from_dict(): @@ -907,25 +985,27 @@ def test_list_metadata_stores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: client.list_metadata_stores() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListMetadataStoresRequest() + @pytest.mark.asyncio -async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListMetadataStoresRequest): +async def test_list_metadata_stores_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.ListMetadataStoresRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -934,12 +1014,14 @@ async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataStoresResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_metadata_stores(request) @@ -952,7 +1034,7 @@ async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', reque # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListMetadataStoresAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -961,19 +1043,17 @@ async def test_list_metadata_stores_async_from_dict(): def test_list_metadata_stores_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataStoresRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: call.return_value = metadata_service.ListMetadataStoresResponse() client.list_metadata_stores(request) @@ -985,28 +1065,25 @@ def test_list_metadata_stores_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_metadata_stores_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataStoresRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse()) + type(client.transport.list_metadata_stores), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataStoresResponse() + ) await client.list_metadata_stores(request) @@ -1017,104 +1094,87 @@ async def test_list_metadata_stores_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_metadata_stores_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataStoresResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_metadata_stores( - parent='parent_value', - ) + client.list_metadata_stores(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_metadata_stores_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_metadata_stores( - metadata_service.ListMetadataStoresRequest(), - parent='parent_value', + metadata_service.ListMetadataStoresRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_metadata_stores_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataStoresResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataStoresResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_metadata_stores( - parent='parent_value', - ) + response = await client.list_metadata_stores(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_metadata_stores_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_metadata_stores( - metadata_service.ListMetadataStoresRequest(), - parent='parent_value', + metadata_service.ListMetadataStoresRequest(), parent="parent_value", ) def test_list_metadata_stores_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1123,17 +1183,14 @@ def test_list_metadata_stores_pager(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1146,9 +1203,7 @@ def test_list_metadata_stores_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_metadata_stores(request={}) @@ -1156,18 +1211,16 @@ def test_list_metadata_stores_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, metadata_store.MetadataStore) - for i in results) + assert all(isinstance(i, metadata_store.MetadataStore) for i in results) + def test_list_metadata_stores_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1176,17 +1229,14 @@ def test_list_metadata_stores_pages(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1197,19 +1247,20 @@ def test_list_metadata_stores_pages(): RuntimeError, ) pages = list(client.list_metadata_stores(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_metadata_stores_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_stores), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1218,17 +1269,14 @@ async def test_list_metadata_stores_async_pager(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1239,25 +1287,25 @@ async def test_list_metadata_stores_async_pager(): RuntimeError, ) async_pager = await client.list_metadata_stores(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, metadata_store.MetadataStore) - for i in responses) + assert all(isinstance(i, metadata_store.MetadataStore) for i in responses) + @pytest.mark.asyncio async def test_list_metadata_stores_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_stores), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1266,17 +1314,14 @@ async def test_list_metadata_stores_async_pages(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1289,14 +1334,15 @@ async def test_list_metadata_stores_async_pages(): pages = [] async for page_ in (await client.list_metadata_stores(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_metadata_store(transport: str = 'grpc', request_type=metadata_service.DeleteMetadataStoreRequest): +def test_delete_metadata_store( + transport: str = "grpc", request_type=metadata_service.DeleteMetadataStoreRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1305,10 +1351,10 @@ def test_delete_metadata_store(transport: str = 'grpc', request_type=metadata_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_metadata_store(request) @@ -1330,25 +1376,27 @@ def test_delete_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: client.delete_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.DeleteMetadataStoreRequest() + @pytest.mark.asyncio -async def test_delete_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.DeleteMetadataStoreRequest): +async def test_delete_metadata_store_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.DeleteMetadataStoreRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1357,11 +1405,11 @@ async def test_delete_metadata_store_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_metadata_store(request) @@ -1382,20 +1430,18 @@ async def test_delete_metadata_store_async_from_dict(): def test_delete_metadata_store_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_metadata_store), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_metadata_store(request) @@ -1406,28 +1452,25 @@ def test_delete_metadata_store_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_metadata_store_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_metadata_store), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_metadata_store(request) @@ -1438,101 +1481,85 @@ async def test_delete_metadata_store_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_metadata_store_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_metadata_store( - name='name_value', - ) + client.delete_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_metadata_store_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_metadata_store( - metadata_service.DeleteMetadataStoreRequest(), - name='name_value', + metadata_service.DeleteMetadataStoreRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_metadata_store_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_metadata_store( - name='name_value', - ) + response = await client.delete_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_metadata_store_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_metadata_store( - metadata_service.DeleteMetadataStoreRequest(), - name='name_value', + metadata_service.DeleteMetadataStoreRequest(), name="name_value", ) -def test_create_artifact(transport: str = 'grpc', request_type=metadata_service.CreateArtifactRequest): +def test_create_artifact( + transport: str = "grpc", request_type=metadata_service.CreateArtifactRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1540,27 +1567,17 @@ def test_create_artifact(transport: str = 'grpc', request_type=metadata_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact( - name='name_value', - - display_name='display_name_value', - - uri='uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", state=gca_artifact.Artifact.State.PENDING, - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.create_artifact(request) @@ -1575,21 +1592,21 @@ def test_create_artifact(transport: str = 'grpc', request_type=metadata_service. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_artifact_from_dict(): @@ -1600,25 +1617,24 @@ def test_create_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: client.create_artifact() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateArtifactRequest() + @pytest.mark.asyncio -async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateArtifactRequest): +async def test_create_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.CreateArtifactRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1626,20 +1642,20 @@ async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact( - name='name_value', - display_name='display_name_value', - uri='uri_value', - etag='etag_value', - state=gca_artifact.Artifact.State.PENDING, - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact( + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=gca_artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.create_artifact(request) @@ -1652,21 +1668,21 @@ async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -1675,19 +1691,15 @@ async def test_create_artifact_async_from_dict(): def test_create_artifact_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateArtifactRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: call.return_value = gca_artifact.Artifact() client.create_artifact(request) @@ -1699,28 +1711,23 @@ def test_create_artifact_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_artifact_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateArtifactRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) await client.create_artifact(request) @@ -1731,30 +1738,23 @@ async def test_create_artifact_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_artifact_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_artifact( - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) # Establish that the underlying call was made with the expected @@ -1762,49 +1762,45 @@ def test_create_artifact_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].artifact_id == 'artifact_id_value' + assert args[0].artifact_id == "artifact_id_value" def test_create_artifact_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_artifact( metadata_service.CreateArtifactRequest(), - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) @pytest.mark.asyncio async def test_create_artifact_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_artifact( - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) # Establish that the underlying call was made with the expected @@ -1812,34 +1808,33 @@ async def test_create_artifact_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].artifact_id == 'artifact_id_value' + assert args[0].artifact_id == "artifact_id_value" @pytest.mark.asyncio async def test_create_artifact_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_artifact( metadata_service.CreateArtifactRequest(), - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) -def test_get_artifact(transport: str = 'grpc', request_type=metadata_service.GetArtifactRequest): +def test_get_artifact( + transport: str = "grpc", request_type=metadata_service.GetArtifactRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1847,56 +1842,46 @@ def test_get_artifact(transport: str = 'grpc', request_type=metadata_service.Get request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = artifact.Artifact( - name='name_value', + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) - display_name='display_name_value', + response = client.get_artifact(request) - uri='uri_value', + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] - etag='etag_value', - - state=artifact.Artifact.State.PENDING, - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - - ) - - response = client.get_artifact(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == metadata_service.GetArtifactRequest() + assert args[0] == metadata_service.GetArtifactRequest() # Establish that the response is the type that we expect. assert isinstance(response, artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_artifact_from_dict(): @@ -1907,25 +1892,24 @@ def test_get_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: client.get_artifact() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetArtifactRequest() + @pytest.mark.asyncio -async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetArtifactRequest): +async def test_get_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetArtifactRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1933,20 +1917,20 @@ async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact( - name='name_value', - display_name='display_name_value', - uri='uri_value', - etag='etag_value', - state=artifact.Artifact.State.PENDING, - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + artifact.Artifact( + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.get_artifact(request) @@ -1959,21 +1943,21 @@ async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -1982,19 +1966,15 @@ async def test_get_artifact_async_from_dict(): def test_get_artifact_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetArtifactRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: call.return_value = artifact.Artifact() client.get_artifact(request) @@ -2006,27 +1986,20 @@ def test_get_artifact_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_artifact_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetArtifactRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact()) await client.get_artifact(request) @@ -2038,99 +2011,79 @@ async def test_get_artifact_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_artifact_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = artifact.Artifact() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_artifact( - name='name_value', - ) + client.get_artifact(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_artifact_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_artifact( - metadata_service.GetArtifactRequest(), - name='name_value', + metadata_service.GetArtifactRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_artifact_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = artifact.Artifact() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_artifact( - name='name_value', - ) + response = await client.get_artifact(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_artifact_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_artifact( - metadata_service.GetArtifactRequest(), - name='name_value', + metadata_service.GetArtifactRequest(), name="name_value", ) -def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.ListArtifactsRequest): +def test_list_artifacts( + transport: str = "grpc", request_type=metadata_service.ListArtifactsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2138,13 +2091,10 @@ def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListArtifactsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_artifacts(request) @@ -2159,7 +2109,7 @@ def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.L assert isinstance(response, pagers.ListArtifactsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_artifacts_from_dict(): @@ -2170,25 +2120,24 @@ def test_list_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: client.list_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListArtifactsRequest() + @pytest.mark.asyncio -async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListArtifactsRequest): +async def test_list_artifacts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListArtifactsRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2196,13 +2145,13 @@ async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListArtifactsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_artifacts(request) @@ -2215,7 +2164,7 @@ async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListArtifactsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2224,19 +2173,15 @@ async def test_list_artifacts_async_from_dict(): def test_list_artifacts_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListArtifactsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: call.return_value = metadata_service.ListArtifactsResponse() client.list_artifacts(request) @@ -2248,28 +2193,23 @@ def test_list_artifacts_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_artifacts_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListArtifactsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse()) + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListArtifactsResponse() + ) await client.list_artifacts(request) @@ -2280,104 +2220,81 @@ async def test_list_artifacts_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_artifacts_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListArtifactsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_artifacts( - parent='parent_value', - ) + client.list_artifacts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_artifacts_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_artifacts( - metadata_service.ListArtifactsRequest(), - parent='parent_value', + metadata_service.ListArtifactsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_artifacts_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListArtifactsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListArtifactsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_artifacts( - parent='parent_value', - ) + response = await client.list_artifacts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_artifacts_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_artifacts( - metadata_service.ListArtifactsRequest(), - parent='parent_value', + metadata_service.ListArtifactsRequest(), parent="parent_value", ) def test_list_artifacts_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2386,32 +2303,23 @@ def test_list_artifacts_pager(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_artifacts(request={}) @@ -2419,18 +2327,14 @@ def test_list_artifacts_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, artifact.Artifact) - for i in results) + assert all(isinstance(i, artifact.Artifact) for i in results) + def test_list_artifacts_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2439,40 +2343,32 @@ def test_list_artifacts_pages(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) pages = list(client.list_artifacts(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_artifacts_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_artifacts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_artifacts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2481,46 +2377,37 @@ async def test_list_artifacts_async_pager(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) async_pager = await client.list_artifacts(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, artifact.Artifact) - for i in responses) + assert all(isinstance(i, artifact.Artifact) for i in responses) + @pytest.mark.asyncio async def test_list_artifacts_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_artifacts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_artifacts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2529,37 +2416,31 @@ async def test_list_artifacts_async_pages(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_artifacts(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_artifact(transport: str = 'grpc', request_type=metadata_service.UpdateArtifactRequest): +def test_update_artifact( + transport: str = "grpc", request_type=metadata_service.UpdateArtifactRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2567,27 +2448,17 @@ def test_update_artifact(transport: str = 'grpc', request_type=metadata_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact( - name='name_value', - - display_name='display_name_value', - - uri='uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", state=gca_artifact.Artifact.State.PENDING, - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.update_artifact(request) @@ -2602,21 +2473,21 @@ def test_update_artifact(transport: str = 'grpc', request_type=metadata_service. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_update_artifact_from_dict(): @@ -2627,25 +2498,24 @@ def test_update_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: client.update_artifact() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateArtifactRequest() + @pytest.mark.asyncio -async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateArtifactRequest): +async def test_update_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.UpdateArtifactRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2653,20 +2523,20 @@ async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact( - name='name_value', - display_name='display_name_value', - uri='uri_value', - etag='etag_value', - state=gca_artifact.Artifact.State.PENDING, - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact( + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=gca_artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.update_artifact(request) @@ -2679,21 +2549,21 @@ async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -2702,19 +2572,15 @@ async def test_update_artifact_async_from_dict(): def test_update_artifact_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateArtifactRequest() - request.artifact.name = 'artifact.name/value' + request.artifact.name = "artifact.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: call.return_value = gca_artifact.Artifact() client.update_artifact(request) @@ -2726,28 +2592,25 @@ def test_update_artifact_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'artifact.name=artifact.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "artifact.name=artifact.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_artifact_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateArtifactRequest() - request.artifact.name = 'artifact.name/value' + request.artifact.name = "artifact.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) await client.update_artifact(request) @@ -2758,29 +2621,24 @@ async def test_update_artifact_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'artifact.name=artifact.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "artifact.name=artifact.name/value",) in kw[ + "metadata" + ] def test_update_artifact_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_artifact( - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2788,45 +2646,41 @@ def test_update_artifact_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_artifact_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_artifact( metadata_service.UpdateArtifactRequest(), - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_artifact_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_artifact( - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2834,31 +2688,30 @@ async def test_update_artifact_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_artifact_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_artifact( metadata_service.UpdateArtifactRequest(), - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_create_context(transport: str = 'grpc', request_type=metadata_service.CreateContextRequest): +def test_create_context( + transport: str = "grpc", request_type=metadata_service.CreateContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2866,25 +2719,16 @@ def test_create_context(transport: str = 'grpc', request_type=metadata_service.C request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - - parent_contexts=['parent_contexts_value'], - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.create_context(request) @@ -2899,19 +2743,19 @@ def test_create_context(transport: str = 'grpc', request_type=metadata_service.C assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_context_from_dict(): @@ -2922,25 +2766,24 @@ def test_create_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: client.create_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateContextRequest() + @pytest.mark.asyncio -async def test_create_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateContextRequest): +async def test_create_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.CreateContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2948,19 +2791,19 @@ async def test_create_context_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context( - name='name_value', - display_name='display_name_value', - etag='etag_value', - parent_contexts=['parent_contexts_value'], - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.create_context(request) @@ -2973,19 +2816,19 @@ async def test_create_context_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -2994,19 +2837,15 @@ async def test_create_context_async_from_dict(): def test_create_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateContextRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: call.return_value = gca_context.Context() client.create_context(request) @@ -3018,27 +2857,20 @@ def test_create_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateContextRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) await client.create_context(request) @@ -3050,30 +2882,23 @@ async def test_create_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_context( - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) # Establish that the underlying call was made with the expected @@ -3081,39 +2906,33 @@ def test_create_context_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].context_id == 'context_id_value' + assert args[0].context_id == "context_id_value" def test_create_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_context( metadata_service.CreateContextRequest(), - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) @pytest.mark.asyncio async def test_create_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() @@ -3121,9 +2940,9 @@ async def test_create_context_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_context( - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) # Establish that the underlying call was made with the expected @@ -3131,34 +2950,33 @@ async def test_create_context_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].context_id == 'context_id_value' + assert args[0].context_id == "context_id_value" @pytest.mark.asyncio async def test_create_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_context( metadata_service.CreateContextRequest(), - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) -def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetContextRequest): +def test_get_context( + transport: str = "grpc", request_type=metadata_service.GetContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3166,25 +2984,16 @@ def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetC request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = context.Context( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - - parent_contexts=['parent_contexts_value'], - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.get_context(request) @@ -3199,19 +3008,19 @@ def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetC assert isinstance(response, context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_context_from_dict(): @@ -3222,25 +3031,24 @@ def test_get_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: client.get_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetContextRequest() + @pytest.mark.asyncio -async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetContextRequest): +async def test_get_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3248,19 +3056,19 @@ async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=m request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context( - name='name_value', - display_name='display_name_value', - etag='etag_value', - parent_contexts=['parent_contexts_value'], - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.get_context(request) @@ -3273,19 +3081,19 @@ async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=m # Establish that the response is the type that we expect. assert isinstance(response, context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -3294,19 +3102,15 @@ async def test_get_context_async_from_dict(): def test_get_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: call.return_value = context.Context() client.get_context(request) @@ -3318,27 +3122,20 @@ def test_get_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) await client.get_context(request) @@ -3350,99 +3147,79 @@ async def test_get_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_context( - name='name_value', - ) + client.get_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_context( - metadata_service.GetContextRequest(), - name='name_value', + metadata_service.GetContextRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = context.Context() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_context( - name='name_value', - ) + response = await client.get_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_context( - metadata_service.GetContextRequest(), - name='name_value', + metadata_service.GetContextRequest(), name="name_value", ) -def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.ListContextsRequest): +def test_list_contexts( + transport: str = "grpc", request_type=metadata_service.ListContextsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3450,13 +3227,10 @@ def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.Li request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListContextsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_contexts(request) @@ -3471,7 +3245,7 @@ def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.Li assert isinstance(response, pagers.ListContextsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_contexts_from_dict(): @@ -3482,25 +3256,24 @@ def test_list_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: client.list_contexts() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListContextsRequest() + @pytest.mark.asyncio -async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListContextsRequest): +async def test_list_contexts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListContextsRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3508,13 +3281,13 @@ async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_contexts(request) @@ -3527,7 +3300,7 @@ async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListContextsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3536,19 +3309,15 @@ async def test_list_contexts_async_from_dict(): def test_list_contexts_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListContextsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: call.return_value = metadata_service.ListContextsResponse() client.list_contexts(request) @@ -3560,28 +3329,23 @@ def test_list_contexts_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_contexts_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListContextsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse()) + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse() + ) await client.list_contexts(request) @@ -3592,138 +3356,100 @@ async def test_list_contexts_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_list_contexts_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) +def test_list_contexts_flattened(): + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListContextsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_contexts( - parent='parent_value', - ) + client.list_contexts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_contexts_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_contexts( - metadata_service.ListContextsRequest(), - parent='parent_value', + metadata_service.ListContextsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_contexts_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListContextsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_contexts( - parent='parent_value', - ) + response = await client.list_contexts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_contexts_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_contexts( - metadata_service.ListContextsRequest(), - parent='parent_value', + metadata_service.ListContextsRequest(), parent="parent_value", ) def test_list_contexts_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_contexts(request={}) @@ -3731,147 +3457,102 @@ def test_list_contexts_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, context.Context) - for i in results) + assert all(isinstance(i, context.Context) for i in results) + def test_list_contexts_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) pages = list(client.list_contexts(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_contexts_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_contexts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) async_pager = await client.list_contexts(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, context.Context) - for i in responses) + assert all(isinstance(i, context.Context) for i in responses) + @pytest.mark.asyncio async def test_list_contexts_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_contexts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_contexts(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_context(transport: str = 'grpc', request_type=metadata_service.UpdateContextRequest): +def test_update_context( + transport: str = "grpc", request_type=metadata_service.UpdateContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3879,25 +3560,16 @@ def test_update_context(transport: str = 'grpc', request_type=metadata_service.U request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - - parent_contexts=['parent_contexts_value'], - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.update_context(request) @@ -3912,19 +3584,19 @@ def test_update_context(transport: str = 'grpc', request_type=metadata_service.U assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_update_context_from_dict(): @@ -3935,25 +3607,24 @@ def test_update_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: client.update_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateContextRequest() + @pytest.mark.asyncio -async def test_update_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateContextRequest): +async def test_update_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.UpdateContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3961,19 +3632,19 @@ async def test_update_context_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context( - name='name_value', - display_name='display_name_value', - etag='etag_value', - parent_contexts=['parent_contexts_value'], - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.update_context(request) @@ -3986,19 +3657,19 @@ async def test_update_context_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -4007,19 +3678,15 @@ async def test_update_context_async_from_dict(): def test_update_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateContextRequest() - request.context.name = 'context.name/value' + request.context.name = "context.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: call.return_value = gca_context.Context() client.update_context(request) @@ -4031,27 +3698,22 @@ def test_update_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context.name=context.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateContextRequest() - request.context.name = 'context.name/value' + request.context.name = "context.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) await client.update_context(request) @@ -4063,29 +3725,24 @@ async def test_update_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context.name=context.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ + "metadata" + ] def test_update_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_context( - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -4093,36 +3750,30 @@ def test_update_context_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_context( metadata_service.UpdateContextRequest(), - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() @@ -4130,8 +3781,8 @@ async def test_update_context_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_context( - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -4139,31 +3790,30 @@ async def test_update_context_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_context( metadata_service.UpdateContextRequest(), - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_context(transport: str = 'grpc', request_type=metadata_service.DeleteContextRequest): +def test_delete_context( + transport: str = "grpc", request_type=metadata_service.DeleteContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4171,11 +3821,9 @@ def test_delete_context(transport: str = 'grpc', request_type=metadata_service.D request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_context(request) @@ -4197,25 +3845,24 @@ def test_delete_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: client.delete_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.DeleteContextRequest() + @pytest.mark.asyncio -async def test_delete_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.DeleteContextRequest): +async def test_delete_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.DeleteContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4223,12 +3870,10 @@ async def test_delete_context_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_context(request) @@ -4249,20 +3894,16 @@ async def test_delete_context_async_from_dict(): def test_delete_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_context(request) @@ -4273,28 +3914,23 @@ def test_delete_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_context(request) @@ -4305,101 +3941,82 @@ async def test_delete_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_context( - name='name_value', - ) + client.delete_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_context( - metadata_service.DeleteContextRequest(), - name='name_value', + metadata_service.DeleteContextRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_context( - name='name_value', - ) + response = await client.delete_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_context( - metadata_service.DeleteContextRequest(), - name='name_value', + metadata_service.DeleteContextRequest(), name="name_value", ) -def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_type=metadata_service.AddContextArtifactsAndExecutionsRequest): +def test_add_context_artifacts_and_executions( + transport: str = "grpc", + request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4408,11 +4025,10 @@ def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_t # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse( - ) + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() response = client.add_context_artifacts_and_executions(request) @@ -4424,7 +4040,9 @@ def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_t # Establish that the response is the type that we expect. - assert isinstance(response, metadata_service.AddContextArtifactsAndExecutionsResponse) + assert isinstance( + response, metadata_service.AddContextArtifactsAndExecutionsResponse + ) def test_add_context_artifacts_and_executions_from_dict(): @@ -4435,25 +4053,27 @@ def test_add_context_artifacts_and_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: client.add_context_artifacts_and_executions() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddContextArtifactsAndExecutionsRequest): +async def test_add_context_artifacts_and_executions_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4462,11 +4082,12 @@ async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) response = await client.add_context_artifacts_and_executions(request) @@ -4477,7 +4098,9 @@ async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, metadata_service.AddContextArtifactsAndExecutionsResponse) + assert isinstance( + response, metadata_service.AddContextArtifactsAndExecutionsResponse + ) @pytest.mark.asyncio @@ -4486,19 +4109,17 @@ async def test_add_context_artifacts_and_executions_async_from_dict(): def test_add_context_artifacts_and_executions_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextArtifactsAndExecutionsRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() client.add_context_artifacts_and_executions(request) @@ -4510,28 +4131,25 @@ def test_add_context_artifacts_and_executions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextArtifactsAndExecutionsRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse()) + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) await client.add_context_artifacts_and_executions(request) @@ -4542,30 +4160,25 @@ async def test_add_context_artifacts_and_executions_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] def test_add_context_artifacts_and_executions_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.add_context_artifacts_and_executions( - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) # Establish that the underlying call was made with the expected @@ -4573,49 +4186,47 @@ def test_add_context_artifacts_and_executions_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].artifacts == ['artifacts_value'] + assert args[0].artifacts == ["artifacts_value"] - assert args[0].executions == ['executions_value'] + assert args[0].executions == ["executions_value"] def test_add_context_artifacts_and_executions_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.add_context_artifacts_and_executions( metadata_service.AddContextArtifactsAndExecutionsRequest(), - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.add_context_artifacts_and_executions( - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) # Establish that the underlying call was made with the expected @@ -4623,34 +4234,33 @@ async def test_add_context_artifacts_and_executions_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].artifacts == ['artifacts_value'] + assert args[0].artifacts == ["artifacts_value"] - assert args[0].executions == ['executions_value'] + assert args[0].executions == ["executions_value"] @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.add_context_artifacts_and_executions( metadata_service.AddContextArtifactsAndExecutionsRequest(), - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) -def test_add_context_children(transport: str = 'grpc', request_type=metadata_service.AddContextChildrenRequest): +def test_add_context_children( + transport: str = "grpc", request_type=metadata_service.AddContextChildrenRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4659,11 +4269,10 @@ def test_add_context_children(transport: str = 'grpc', request_type=metadata_ser # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextChildrenResponse( - ) + call.return_value = metadata_service.AddContextChildrenResponse() response = client.add_context_children(request) @@ -4686,25 +4295,27 @@ def test_add_context_children_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: client.add_context_children() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.AddContextChildrenRequest() + @pytest.mark.asyncio -async def test_add_context_children_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddContextChildrenRequest): +async def test_add_context_children_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddContextChildrenRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4713,11 +4324,12 @@ async def test_add_context_children_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) response = await client.add_context_children(request) @@ -4737,19 +4349,17 @@ async def test_add_context_children_async_from_dict(): def test_add_context_children_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextChildrenRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: call.return_value = metadata_service.AddContextChildrenResponse() client.add_context_children(request) @@ -4761,28 +4371,25 @@ def test_add_context_children_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_context_children_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextChildrenRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse()) + type(client.transport.add_context_children), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) await client.add_context_children(request) @@ -4793,29 +4400,23 @@ async def test_add_context_children_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] def test_add_context_children_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextChildrenResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.add_context_children( - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", child_contexts=["child_contexts_value"], ) # Establish that the underlying call was made with the expected @@ -4823,45 +4424,42 @@ def test_add_context_children_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].child_contexts == ['child_contexts_value'] + assert args[0].child_contexts == ["child_contexts_value"] def test_add_context_children_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.add_context_children( metadata_service.AddContextChildrenRequest(), - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", + child_contexts=["child_contexts_value"], ) @pytest.mark.asyncio async def test_add_context_children_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextChildrenResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.add_context_children( - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", child_contexts=["child_contexts_value"], ) # Establish that the underlying call was made with the expected @@ -4869,31 +4467,31 @@ async def test_add_context_children_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].child_contexts == ['child_contexts_value'] + assert args[0].child_contexts == ["child_contexts_value"] @pytest.mark.asyncio async def test_add_context_children_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.add_context_children( metadata_service.AddContextChildrenRequest(), - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", + child_contexts=["child_contexts_value"], ) -def test_query_context_lineage_subgraph(transport: str = 'grpc', request_type=metadata_service.QueryContextLineageSubgraphRequest): +def test_query_context_lineage_subgraph( + transport: str = "grpc", + request_type=metadata_service.QueryContextLineageSubgraphRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4902,11 +4500,10 @@ def test_query_context_lineage_subgraph(transport: str = 'grpc', request_type=me # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph( - ) + call.return_value = lineage_subgraph.LineageSubgraph() response = client.query_context_lineage_subgraph(request) @@ -4929,25 +4526,27 @@ def test_query_context_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: client.query_context_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryContextLineageSubgraphRequest): +async def test_query_context_lineage_subgraph_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.QueryContextLineageSubgraphRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4956,11 +4555,12 @@ async def test_query_context_lineage_subgraph_async(transport: str = 'grpc_async # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) response = await client.query_context_lineage_subgraph(request) @@ -4980,19 +4580,17 @@ async def test_query_context_lineage_subgraph_async_from_dict(): def test_query_context_lineage_subgraph_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryContextLineageSubgraphRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: call.return_value = lineage_subgraph.LineageSubgraph() client.query_context_lineage_subgraph(request) @@ -5004,28 +4602,25 @@ def test_query_context_lineage_subgraph_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio async def test_query_context_lineage_subgraph_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryContextLineageSubgraphRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) await client.query_context_lineage_subgraph(request) @@ -5036,125 +4631,104 @@ async def test_query_context_lineage_subgraph_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] def test_query_context_lineage_subgraph_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.query_context_lineage_subgraph( - context='context_value', - ) + client.query_context_lineage_subgraph(context="context_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" def test_query_context_lineage_subgraph_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.query_context_lineage_subgraph( metadata_service.QueryContextLineageSubgraphRequest(), - context='context_value', + context="context_value", ) @pytest.mark.asyncio async def test_query_context_lineage_subgraph_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.query_context_lineage_subgraph( - context='context_value', - ) + response = await client.query_context_lineage_subgraph(context="context_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" @pytest.mark.asyncio async def test_query_context_lineage_subgraph_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.query_context_lineage_subgraph( metadata_service.QueryContextLineageSubgraphRequest(), - context='context_value', + context="context_value", ) -def test_create_execution(transport: str = 'grpc', request_type=metadata_service.CreateExecutionRequest): +def test_create_execution( + transport: str = "grpc", request_type=metadata_service.CreateExecutionRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. request = request_type() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = gca_execution.Execution( - name='name_value', - - display_name='display_name_value', - - state=gca_execution.Execution.State.NEW, - - etag='etag_value', - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.create_execution(request) @@ -5169,19 +4743,19 @@ def test_create_execution(transport: str = 'grpc', request_type=metadata_service assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_execution_from_dict(): @@ -5192,25 +4766,25 @@ def test_create_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: client.create_execution() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateExecutionRequest() + @pytest.mark.asyncio -async def test_create_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateExecutionRequest): +async def test_create_execution_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.CreateExecutionRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5218,19 +4792,19 @@ async def test_create_execution_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution( - name='name_value', - display_name='display_name_value', - state=gca_execution.Execution.State.NEW, - etag='etag_value', - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.create_execution(request) @@ -5243,19 +4817,19 @@ async def test_create_execution_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -5264,19 +4838,15 @@ async def test_create_execution_async_from_dict(): def test_create_execution_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateExecutionRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: call.return_value = gca_execution.Execution() client.create_execution(request) @@ -5288,28 +4858,23 @@ def test_create_execution_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_execution_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateExecutionRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) await client.create_execution(request) @@ -5320,30 +4885,23 @@ async def test_create_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_execution_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_execution( - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) # Establish that the underlying call was made with the expected @@ -5351,49 +4909,45 @@ def test_create_execution_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].execution_id == 'execution_id_value' + assert args[0].execution_id == "execution_id_value" def test_create_execution_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_execution( metadata_service.CreateExecutionRequest(), - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) @pytest.mark.asyncio async def test_create_execution_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_execution( - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) # Establish that the underlying call was made with the expected @@ -5401,34 +4955,33 @@ async def test_create_execution_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].execution_id == 'execution_id_value' + assert args[0].execution_id == "execution_id_value" @pytest.mark.asyncio async def test_create_execution_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_execution( metadata_service.CreateExecutionRequest(), - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) -def test_get_execution(transport: str = 'grpc', request_type=metadata_service.GetExecutionRequest): +def test_get_execution( + transport: str = "grpc", request_type=metadata_service.GetExecutionRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5436,25 +4989,16 @@ def test_get_execution(transport: str = 'grpc', request_type=metadata_service.Ge request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = execution.Execution( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=execution.Execution.State.NEW, - - etag='etag_value', - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.get_execution(request) @@ -5469,19 +5013,19 @@ def test_get_execution(transport: str = 'grpc', request_type=metadata_service.Ge assert isinstance(response, execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_execution_from_dict(): @@ -5492,25 +5036,24 @@ def test_get_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: client.get_execution() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetExecutionRequest() + @pytest.mark.asyncio -async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetExecutionRequest): +async def test_get_execution_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetExecutionRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5518,19 +5061,19 @@ async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution( - name='name_value', - display_name='display_name_value', - state=execution.Execution.State.NEW, - etag='etag_value', - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + execution.Execution( + name="name_value", + display_name="display_name_value", + state=execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.get_execution(request) @@ -5543,19 +5086,19 @@ async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -5564,19 +5107,15 @@ async def test_get_execution_async_from_dict(): def test_get_execution_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetExecutionRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: call.return_value = execution.Execution() client.get_execution(request) @@ -5588,27 +5127,20 @@ def test_get_execution_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_execution_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetExecutionRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) await client.get_execution(request) @@ -5620,99 +5152,79 @@ async def test_get_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_execution_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_execution( - name='name_value', - ) + client.get_execution(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_execution_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_execution( - metadata_service.GetExecutionRequest(), - name='name_value', + metadata_service.GetExecutionRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_execution_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = execution.Execution() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_execution( - name='name_value', - ) + response = await client.get_execution(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_execution_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_execution( - metadata_service.GetExecutionRequest(), - name='name_value', + metadata_service.GetExecutionRequest(), name="name_value", ) -def test_list_executions(transport: str = 'grpc', request_type=metadata_service.ListExecutionsRequest): +def test_list_executions( + transport: str = "grpc", request_type=metadata_service.ListExecutionsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5720,13 +5232,10 @@ def test_list_executions(transport: str = 'grpc', request_type=metadata_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListExecutionsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_executions(request) @@ -5741,7 +5250,7 @@ def test_list_executions(transport: str = 'grpc', request_type=metadata_service. assert isinstance(response, pagers.ListExecutionsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_executions_from_dict(): @@ -5752,25 +5261,24 @@ def test_list_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: client.list_executions() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListExecutionsRequest() + @pytest.mark.asyncio -async def test_list_executions_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListExecutionsRequest): +async def test_list_executions_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListExecutionsRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5778,13 +5286,13 @@ async def test_list_executions_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_executions(request) @@ -5797,7 +5305,7 @@ async def test_list_executions_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListExecutionsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -5806,19 +5314,15 @@ async def test_list_executions_async_from_dict(): def test_list_executions_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListExecutionsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: call.return_value = metadata_service.ListExecutionsResponse() client.list_executions(request) @@ -5830,28 +5334,23 @@ def test_list_executions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_executions_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListExecutionsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse()) + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse() + ) await client.list_executions(request) @@ -5862,104 +5361,81 @@ async def test_list_executions_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_executions_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListExecutionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_executions( - parent='parent_value', - ) + client.list_executions(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_executions_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_executions( - metadata_service.ListExecutionsRequest(), - parent='parent_value', + metadata_service.ListExecutionsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_executions_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListExecutionsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_executions( - parent='parent_value', - ) + response = await client.list_executions(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_executions_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_executions( - metadata_service.ListExecutionsRequest(), - parent='parent_value', + metadata_service.ListExecutionsRequest(), parent="parent_value", ) def test_list_executions_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -5968,32 +5444,23 @@ def test_list_executions_pager(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_executions(request={}) @@ -6001,18 +5468,14 @@ def test_list_executions_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, execution.Execution) - for i in results) + assert all(isinstance(i, execution.Execution) for i in results) + def test_list_executions_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -6021,40 +5484,32 @@ def test_list_executions_pages(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) pages = list(client.list_executions(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_executions_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_executions), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -6063,46 +5518,37 @@ async def test_list_executions_async_pager(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) async_pager = await client.list_executions(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, execution.Execution) - for i in responses) + assert all(isinstance(i, execution.Execution) for i in responses) + @pytest.mark.asyncio async def test_list_executions_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_executions), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -6111,37 +5557,31 @@ async def test_list_executions_async_pages(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_executions(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_execution(transport: str = 'grpc', request_type=metadata_service.UpdateExecutionRequest): +def test_update_execution( + transport: str = "grpc", request_type=metadata_service.UpdateExecutionRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6149,25 +5589,16 @@ def test_update_execution(transport: str = 'grpc', request_type=metadata_service request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=gca_execution.Execution.State.NEW, - - etag='etag_value', - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.update_execution(request) @@ -6182,19 +5613,19 @@ def test_update_execution(transport: str = 'grpc', request_type=metadata_service assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_update_execution_from_dict(): @@ -6205,25 +5636,25 @@ def test_update_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: client.update_execution() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateExecutionRequest() + @pytest.mark.asyncio -async def test_update_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateExecutionRequest): +async def test_update_execution_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.UpdateExecutionRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6231,19 +5662,19 @@ async def test_update_execution_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution( - name='name_value', - display_name='display_name_value', - state=gca_execution.Execution.State.NEW, - etag='etag_value', - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.update_execution(request) @@ -6256,19 +5687,19 @@ async def test_update_execution_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -6277,19 +5708,15 @@ async def test_update_execution_async_from_dict(): def test_update_execution_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateExecutionRequest() - request.execution.name = 'execution.name/value' + request.execution.name = "execution.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: call.return_value = gca_execution.Execution() client.update_execution(request) @@ -6301,28 +5728,25 @@ def test_update_execution_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution.name=execution.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_execution_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateExecutionRequest() - request.execution.name = 'execution.name/value' + request.execution.name = "execution.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) await client.update_execution(request) @@ -6333,29 +5757,24 @@ async def test_update_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution.name=execution.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ + "metadata" + ] def test_update_execution_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_execution( - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -6363,45 +5782,41 @@ def test_update_execution_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_execution_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_execution( metadata_service.UpdateExecutionRequest(), - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_execution_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_execution( - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -6409,31 +5824,30 @@ async def test_update_execution_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_execution_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_execution( metadata_service.UpdateExecutionRequest(), - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_add_execution_events(transport: str = 'grpc', request_type=metadata_service.AddExecutionEventsRequest): +def test_add_execution_events( + transport: str = "grpc", request_type=metadata_service.AddExecutionEventsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6442,11 +5856,10 @@ def test_add_execution_events(transport: str = 'grpc', request_type=metadata_ser # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddExecutionEventsResponse( - ) + call.return_value = metadata_service.AddExecutionEventsResponse() response = client.add_execution_events(request) @@ -6469,25 +5882,27 @@ def test_add_execution_events_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: client.add_execution_events() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.AddExecutionEventsRequest() + @pytest.mark.asyncio -async def test_add_execution_events_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddExecutionEventsRequest): +async def test_add_execution_events_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddExecutionEventsRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6496,11 +5911,12 @@ async def test_add_execution_events_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddExecutionEventsResponse() + ) response = await client.add_execution_events(request) @@ -6520,19 +5936,17 @@ async def test_add_execution_events_async_from_dict(): def test_add_execution_events_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddExecutionEventsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: call.return_value = metadata_service.AddExecutionEventsResponse() client.add_execution_events(request) @@ -6544,28 +5958,25 @@ def test_add_execution_events_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_execution_events_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddExecutionEventsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse()) + type(client.transport.add_execution_events), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddExecutionEventsResponse() + ) await client.add_execution_events(request) @@ -6576,29 +5987,24 @@ async def test_add_execution_events_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] def test_add_execution_events_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddExecutionEventsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.add_execution_events( - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) # Establish that the underlying call was made with the expected @@ -6606,45 +6012,43 @@ def test_add_execution_events_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" - assert args[0].events == [event.Event(artifact='artifact_value')] + assert args[0].events == [event.Event(artifact="artifact_value")] def test_add_execution_events_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.add_execution_events( metadata_service.AddExecutionEventsRequest(), - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) @pytest.mark.asyncio async def test_add_execution_events_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddExecutionEventsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddExecutionEventsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.add_execution_events( - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) # Establish that the underlying call was made with the expected @@ -6652,31 +6056,31 @@ async def test_add_execution_events_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" - assert args[0].events == [event.Event(artifact='artifact_value')] + assert args[0].events == [event.Event(artifact="artifact_value")] @pytest.mark.asyncio async def test_add_execution_events_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.add_execution_events( metadata_service.AddExecutionEventsRequest(), - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) -def test_query_execution_inputs_and_outputs(transport: str = 'grpc', request_type=metadata_service.QueryExecutionInputsAndOutputsRequest): +def test_query_execution_inputs_and_outputs( + transport: str = "grpc", + request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6685,11 +6089,10 @@ def test_query_execution_inputs_and_outputs(transport: str = 'grpc', request_typ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph( - ) + call.return_value = lineage_subgraph.LineageSubgraph() response = client.query_execution_inputs_and_outputs(request) @@ -6712,25 +6115,27 @@ def test_query_execution_inputs_and_outputs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: client.query_execution_inputs_and_outputs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + @pytest.mark.asyncio -async def test_query_execution_inputs_and_outputs_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryExecutionInputsAndOutputsRequest): +async def test_query_execution_inputs_and_outputs_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6739,11 +6144,12 @@ async def test_query_execution_inputs_and_outputs_async(transport: str = 'grpc_a # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) response = await client.query_execution_inputs_and_outputs(request) @@ -6763,19 +6169,17 @@ async def test_query_execution_inputs_and_outputs_async_from_dict(): def test_query_execution_inputs_and_outputs_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryExecutionInputsAndOutputsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: call.return_value = lineage_subgraph.LineageSubgraph() client.query_execution_inputs_and_outputs(request) @@ -6787,28 +6191,25 @@ def test_query_execution_inputs_and_outputs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryExecutionInputsAndOutputsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) await client.query_execution_inputs_and_outputs(request) @@ -6819,70 +6220,61 @@ async def test_query_execution_inputs_and_outputs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] def test_query_execution_inputs_and_outputs_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.query_execution_inputs_and_outputs( - execution='execution_value', - ) + client.query_execution_inputs_and_outputs(execution="execution_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" def test_query_execution_inputs_and_outputs_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.query_execution_inputs_and_outputs( metadata_service.QueryExecutionInputsAndOutputsRequest(), - execution='execution_value', + execution="execution_value", ) @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.query_execution_inputs_and_outputs( - execution='execution_value', + execution="execution_value", ) # Establish that the underlying call was made with the expected @@ -6890,28 +6282,27 @@ async def test_query_execution_inputs_and_outputs_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.query_execution_inputs_and_outputs( metadata_service.QueryExecutionInputsAndOutputsRequest(), - execution='execution_value', + execution="execution_value", ) -def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_service.CreateMetadataSchemaRequest): +def test_create_metadata_schema( + transport: str = "grpc", request_type=metadata_service.CreateMetadataSchemaRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6920,20 +6311,15 @@ def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_metadata_schema.MetadataSchema( - name='name_value', - - schema_version='schema_version_value', - - schema='schema_value', - + name="name_value", + schema_version="schema_version_value", + schema="schema_value", schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - - description='description_value', - + description="description_value", ) response = client.create_metadata_schema(request) @@ -6948,15 +6334,18 @@ def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_s assert isinstance(response, gca_metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_metadata_schema_from_dict(): @@ -6967,25 +6356,27 @@ def test_create_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: client.create_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateMetadataSchemaRequest() + @pytest.mark.asyncio -async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateMetadataSchemaRequest): +async def test_create_metadata_schema_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.CreateMetadataSchemaRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6994,16 +6385,18 @@ async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema( - name='name_value', - schema_version='schema_version_value', - schema='schema_value', - schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_metadata_schema.MetadataSchema( + name="name_value", + schema_version="schema_version_value", + schema="schema_value", + schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + description="description_value", + ) + ) response = await client.create_metadata_schema(request) @@ -7016,15 +6409,18 @@ async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', req # Establish that the response is the type that we expect. assert isinstance(response, gca_metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -7033,19 +6429,17 @@ async def test_create_metadata_schema_async_from_dict(): def test_create_metadata_schema_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataSchemaRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: call.return_value = gca_metadata_schema.MetadataSchema() client.create_metadata_schema(request) @@ -7057,28 +6451,25 @@ def test_create_metadata_schema_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_metadata_schema_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataSchemaRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema()) + type(client.transport.create_metadata_schema), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_metadata_schema.MetadataSchema() + ) await client.create_metadata_schema(request) @@ -7089,30 +6480,25 @@ async def test_create_metadata_schema_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_metadata_schema_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_metadata_schema.MetadataSchema() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_metadata_schema( - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) # Establish that the underlying call was made with the expected @@ -7120,49 +6506,49 @@ def test_create_metadata_schema_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema(name='name_value') + assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema( + name="name_value" + ) - assert args[0].metadata_schema_id == 'metadata_schema_id_value' + assert args[0].metadata_schema_id == "metadata_schema_id_value" def test_create_metadata_schema_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_metadata_schema( metadata_service.CreateMetadataSchemaRequest(), - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) @pytest.mark.asyncio async def test_create_metadata_schema_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_metadata_schema.MetadataSchema() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_metadata_schema.MetadataSchema() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_metadata_schema( - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) # Establish that the underlying call was made with the expected @@ -7170,34 +6556,35 @@ async def test_create_metadata_schema_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema(name='name_value') + assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema( + name="name_value" + ) - assert args[0].metadata_schema_id == 'metadata_schema_id_value' + assert args[0].metadata_schema_id == "metadata_schema_id_value" @pytest.mark.asyncio async def test_create_metadata_schema_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_metadata_schema( metadata_service.CreateMetadataSchemaRequest(), - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) -def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_service.GetMetadataSchemaRequest): +def test_get_metadata_schema( + transport: str = "grpc", request_type=metadata_service.GetMetadataSchemaRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7206,20 +6593,15 @@ def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_schema.MetadataSchema( - name='name_value', - - schema_version='schema_version_value', - - schema='schema_value', - + name="name_value", + schema_version="schema_version_value", + schema="schema_value", schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - - description='description_value', - + description="description_value", ) response = client.get_metadata_schema(request) @@ -7234,15 +6616,18 @@ def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_serv assert isinstance(response, metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_metadata_schema_from_dict(): @@ -7253,25 +6638,27 @@ def test_get_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: client.get_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetMetadataSchemaRequest() + @pytest.mark.asyncio -async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetMetadataSchemaRequest): +async def test_get_metadata_schema_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.GetMetadataSchemaRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7280,16 +6667,18 @@ async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema( - name='name_value', - schema_version='schema_version_value', - schema='schema_value', - schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_schema.MetadataSchema( + name="name_value", + schema_version="schema_version_value", + schema="schema_value", + schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + description="description_value", + ) + ) response = await client.get_metadata_schema(request) @@ -7302,15 +6691,18 @@ async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -7319,19 +6711,17 @@ async def test_get_metadata_schema_async_from_dict(): def test_get_metadata_schema_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataSchemaRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: call.return_value = metadata_schema.MetadataSchema() client.get_metadata_schema(request) @@ -7343,28 +6733,25 @@ def test_get_metadata_schema_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_metadata_schema_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataSchemaRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema()) + type(client.transport.get_metadata_schema), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_schema.MetadataSchema() + ) await client.get_metadata_schema(request) @@ -7375,99 +6762,85 @@ async def test_get_metadata_schema_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_metadata_schema_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_schema.MetadataSchema() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_metadata_schema( - name='name_value', - ) + client.get_metadata_schema(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_metadata_schema_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_metadata_schema( - metadata_service.GetMetadataSchemaRequest(), - name='name_value', + metadata_service.GetMetadataSchemaRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_metadata_schema_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_schema.MetadataSchema() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_schema.MetadataSchema() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_metadata_schema( - name='name_value', - ) + response = await client.get_metadata_schema(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_metadata_schema_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_metadata_schema( - metadata_service.GetMetadataSchemaRequest(), - name='name_value', + metadata_service.GetMetadataSchemaRequest(), name="name_value", ) -def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_service.ListMetadataSchemasRequest): +def test_list_metadata_schemas( + transport: str = "grpc", request_type=metadata_service.ListMetadataSchemasRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7476,12 +6849,11 @@ def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataSchemasResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_metadata_schemas(request) @@ -7496,7 +6868,7 @@ def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_se assert isinstance(response, pagers.ListMetadataSchemasPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_metadata_schemas_from_dict(): @@ -7507,25 +6879,27 @@ def test_list_metadata_schemas_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: client.list_metadata_schemas() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListMetadataSchemasRequest() + @pytest.mark.asyncio -async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListMetadataSchemasRequest): +async def test_list_metadata_schemas_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.ListMetadataSchemasRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7534,12 +6908,14 @@ async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataSchemasResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_metadata_schemas(request) @@ -7552,7 +6928,7 @@ async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListMetadataSchemasAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -7561,19 +6937,17 @@ async def test_list_metadata_schemas_async_from_dict(): def test_list_metadata_schemas_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataSchemasRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: call.return_value = metadata_service.ListMetadataSchemasResponse() client.list_metadata_schemas(request) @@ -7585,28 +6959,25 @@ def test_list_metadata_schemas_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_metadata_schemas_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataSchemasRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse()) + type(client.transport.list_metadata_schemas), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataSchemasResponse() + ) await client.list_metadata_schemas(request) @@ -7617,104 +6988,87 @@ async def test_list_metadata_schemas_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_metadata_schemas_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataSchemasResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_metadata_schemas( - parent='parent_value', - ) + client.list_metadata_schemas(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_metadata_schemas_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_metadata_schemas( - metadata_service.ListMetadataSchemasRequest(), - parent='parent_value', + metadata_service.ListMetadataSchemasRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_metadata_schemas_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataSchemasResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataSchemasResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_metadata_schemas( - parent='parent_value', - ) + response = await client.list_metadata_schemas(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_metadata_schemas_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_metadata_schemas( - metadata_service.ListMetadataSchemasRequest(), - parent='parent_value', + metadata_service.ListMetadataSchemasRequest(), parent="parent_value", ) def test_list_metadata_schemas_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7723,17 +7077,14 @@ def test_list_metadata_schemas_pager(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7746,9 +7097,7 @@ def test_list_metadata_schemas_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_metadata_schemas(request={}) @@ -7756,18 +7105,16 @@ def test_list_metadata_schemas_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, metadata_schema.MetadataSchema) - for i in results) + assert all(isinstance(i, metadata_schema.MetadataSchema) for i in results) + def test_list_metadata_schemas_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7776,17 +7123,14 @@ def test_list_metadata_schemas_pages(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7797,19 +7141,20 @@ def test_list_metadata_schemas_pages(): RuntimeError, ) pages = list(client.list_metadata_schemas(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_metadata_schemas_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_schemas), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7818,17 +7163,14 @@ async def test_list_metadata_schemas_async_pager(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7839,25 +7181,25 @@ async def test_list_metadata_schemas_async_pager(): RuntimeError, ) async_pager = await client.list_metadata_schemas(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, metadata_schema.MetadataSchema) - for i in responses) + assert all(isinstance(i, metadata_schema.MetadataSchema) for i in responses) + @pytest.mark.asyncio async def test_list_metadata_schemas_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_schemas), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7866,17 +7208,14 @@ async def test_list_metadata_schemas_async_pages(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7889,7 +7228,7 @@ async def test_list_metadata_schemas_async_pages(): pages = [] async for page_ in (await client.list_metadata_schemas(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -7900,8 +7239,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -7920,8 +7258,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MetadataServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -7949,13 +7286,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MetadataServiceGrpcTransport, - transports.MetadataServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -7963,13 +7303,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MetadataServiceGrpcTransport, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MetadataServiceGrpcTransport,) def test_metadata_service_base_transport_error(): @@ -7977,13 +7312,15 @@ def test_metadata_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MetadataServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_metadata_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MetadataServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -7992,32 +7329,32 @@ def test_metadata_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_metadata_store', - 'get_metadata_store', - 'list_metadata_stores', - 'delete_metadata_store', - 'create_artifact', - 'get_artifact', - 'list_artifacts', - 'update_artifact', - 'create_context', - 'get_context', - 'list_contexts', - 'update_context', - 'delete_context', - 'add_context_artifacts_and_executions', - 'add_context_children', - 'query_context_lineage_subgraph', - 'create_execution', - 'get_execution', - 'list_executions', - 'update_execution', - 'add_execution_events', - 'query_execution_inputs_and_outputs', - 'create_metadata_schema', - 'get_metadata_schema', - 'list_metadata_schemas', - ) + "create_metadata_store", + "get_metadata_store", + "list_metadata_stores", + "delete_metadata_store", + "create_artifact", + "get_artifact", + "list_artifacts", + "update_artifact", + "create_context", + "get_context", + "list_contexts", + "update_context", + "delete_context", + "add_context_artifacts_and_executions", + "add_context_children", + "query_context_lineage_subgraph", + "create_execution", + "get_execution", + "list_executions", + "update_execution", + "add_execution_events", + "query_execution_inputs_and_outputs", + "create_metadata_schema", + "get_metadata_schema", + "list_metadata_schemas", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -8030,23 +7367,28 @@ def test_metadata_service_base_transport(): def test_metadata_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MetadataServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_metadata_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MetadataServiceTransport() @@ -8055,11 +7397,11 @@ def test_metadata_service_base_transport_with_adc(): def test_metadata_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MetadataServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -8067,19 +7409,25 @@ def test_metadata_service_auth_adc(): def test_metadata_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MetadataServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MetadataServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) -def test_metadata_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) +def test_metadata_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -8088,15 +7436,13 @@ def test_metadata_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -8111,38 +7457,40 @@ def test_metadata_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_metadata_service_host_no_port(): client = MetadataServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_metadata_service_host_with_port(): client = MetadataServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_metadata_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MetadataServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -8150,12 +7498,11 @@ def test_metadata_service_grpc_transport_channel(): def test_metadata_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MetadataServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -8164,12 +7511,22 @@ def test_metadata_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) def test_metadata_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -8178,7 +7535,7 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -8194,9 +7551,7 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8210,17 +7565,23 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) -def test_metadata_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) +def test_metadata_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -8237,9 +7598,7 @@ def test_metadata_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8252,16 +7611,12 @@ def test_metadata_service_transport_channel_mtls_with_adc( def test_metadata_service_grpc_lro_client(): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -8269,16 +7624,12 @@ def test_metadata_service_grpc_lro_client(): def test_metadata_service_grpc_lro_async_client(): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -8290,18 +7641,24 @@ def test_artifact_path(): metadata_store = "whelk" artifact = "octopus" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(project=project, location=location, metadata_store=metadata_store, artifact=artifact, ) - actual = MetadataServiceClient.artifact_path(project, location, metadata_store, artifact) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) + actual = MetadataServiceClient.artifact_path( + project, location, metadata_store, artifact + ) assert expected == actual def test_parse_artifact_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "metadata_store": "cuttlefish", - "artifact": "mussel", - + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "artifact": "mussel", } path = MetadataServiceClient.artifact_path(**expected) @@ -8309,24 +7666,31 @@ def test_parse_artifact_path(): actual = MetadataServiceClient.parse_artifact_path(path) assert expected == actual + def test_context_path(): project = "winkle" location = "nautilus" metadata_store = "scallop" context = "abalone" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(project=project, location=location, metadata_store=metadata_store, context=context, ) - actual = MetadataServiceClient.context_path(project, location, metadata_store, context) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + actual = MetadataServiceClient.context_path( + project, location, metadata_store, context + ) assert expected == actual def test_parse_context_path(): expected = { - "project": "squid", - "location": "clam", - "metadata_store": "whelk", - "context": "octopus", - + "project": "squid", + "location": "clam", + "metadata_store": "whelk", + "context": "octopus", } path = MetadataServiceClient.context_path(**expected) @@ -8334,24 +7698,31 @@ def test_parse_context_path(): actual = MetadataServiceClient.parse_context_path(path) assert expected == actual + def test_execution_path(): project = "oyster" location = "nudibranch" metadata_store = "cuttlefish" execution = "mussel" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(project=project, location=location, metadata_store=metadata_store, execution=execution, ) - actual = MetadataServiceClient.execution_path(project, location, metadata_store, execution) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) + actual = MetadataServiceClient.execution_path( + project, location, metadata_store, execution + ) assert expected == actual def test_parse_execution_path(): expected = { - "project": "winkle", - "location": "nautilus", - "metadata_store": "scallop", - "execution": "abalone", - + "project": "winkle", + "location": "nautilus", + "metadata_store": "scallop", + "execution": "abalone", } path = MetadataServiceClient.execution_path(**expected) @@ -8359,24 +7730,31 @@ def test_parse_execution_path(): actual = MetadataServiceClient.parse_execution_path(path) assert expected == actual + def test_metadata_schema_path(): project = "squid" location = "clam" metadata_store = "whelk" metadata_schema = "octopus" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format(project=project, location=location, metadata_store=metadata_store, metadata_schema=metadata_schema, ) - actual = MetadataServiceClient.metadata_schema_path(project, location, metadata_store, metadata_schema) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format( + project=project, + location=location, + metadata_store=metadata_store, + metadata_schema=metadata_schema, + ) + actual = MetadataServiceClient.metadata_schema_path( + project, location, metadata_store, metadata_schema + ) assert expected == actual def test_parse_metadata_schema_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "metadata_store": "cuttlefish", - "metadata_schema": "mussel", - + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "metadata_schema": "mussel", } path = MetadataServiceClient.metadata_schema_path(**expected) @@ -8384,22 +7762,26 @@ def test_parse_metadata_schema_path(): actual = MetadataServiceClient.parse_metadata_schema_path(path) assert expected == actual + def test_metadata_store_path(): project = "winkle" location = "nautilus" metadata_store = "scallop" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format(project=project, location=location, metadata_store=metadata_store, ) - actual = MetadataServiceClient.metadata_store_path(project, location, metadata_store) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format( + project=project, location=location, metadata_store=metadata_store, + ) + actual = MetadataServiceClient.metadata_store_path( + project, location, metadata_store + ) assert expected == actual def test_parse_metadata_store_path(): expected = { - "project": "abalone", - "location": "squid", - "metadata_store": "clam", - + "project": "abalone", + "location": "squid", + "metadata_store": "clam", } path = MetadataServiceClient.metadata_store_path(**expected) @@ -8407,18 +7789,20 @@ def test_parse_metadata_store_path(): actual = MetadataServiceClient.parse_metadata_store_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MetadataServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "octopus", - + "billing_account": "octopus", } path = MetadataServiceClient.common_billing_account_path(**expected) @@ -8426,18 +7810,18 @@ def test_parse_common_billing_account_path(): actual = MetadataServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MetadataServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nudibranch", - + "folder": "nudibranch", } path = MetadataServiceClient.common_folder_path(**expected) @@ -8445,18 +7829,18 @@ def test_parse_common_folder_path(): actual = MetadataServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MetadataServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "mussel", - + "organization": "mussel", } path = MetadataServiceClient.common_organization_path(**expected) @@ -8464,18 +7848,18 @@ def test_parse_common_organization_path(): actual = MetadataServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "winkle" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MetadataServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "nautilus", - + "project": "nautilus", } path = MetadataServiceClient.common_project_path(**expected) @@ -8483,20 +7867,22 @@ def test_parse_common_project_path(): actual = MetadataServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "scallop" location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MetadataServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "squid", - "location": "clam", - + "project": "squid", + "location": "clam", } path = MetadataServiceClient.common_location_path(**expected) @@ -8508,17 +7894,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MetadataServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MetadataServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MetadataServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MetadataServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MetadataServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 85cf790381..635122a0ce 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceClient, +) from google.cloud.aiplatform_v1beta1.services.migration_service import pagers from google.cloud.aiplatform_v1beta1.services.migration_service import transports from google.cloud.aiplatform_v1beta1.types import migratable_resource @@ -53,7 +57,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -64,36 +72,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -103,7 +128,7 @@ def test_migration_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_client_get_transport_class(): @@ -117,29 +142,44 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) -def test_migration_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -155,7 +195,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -171,7 +211,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -191,13 +231,15 @@ def test_migration_service_client_client_options(client_class, transport_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -210,26 +252,62 @@ def test_migration_service_client_client_options(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -252,10 +330,18 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -276,9 +362,14 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -292,16 +383,23 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -314,16 +412,24 @@ def test_migration_service_client_client_options_scopes(client_class, transport_ client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -338,10 +444,12 @@ def test_migration_service_client_client_options_credentials_file(client_class, def test_migration_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -354,10 +462,12 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -366,12 +476,11 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_migratable_resources(request) @@ -386,7 +495,7 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_migratable_resources_from_dict(): @@ -397,25 +506,27 @@ def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.SearchMigratableResourcesRequest() + @pytest.mark.asyncio -async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): +async def test_search_migratable_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -424,12 +535,14 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.search_migratable_resources(request) @@ -442,7 +555,7 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -451,19 +564,17 @@ async def test_search_migratable_resources_async_from_dict(): def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -475,10 +586,7 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -490,13 +598,15 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) await client.search_migratable_resources(request) @@ -507,49 +617,39 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources( - parent='parent_value', - ) + client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) @@ -561,24 +661,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources( - parent='parent_value', - ) + response = await client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -591,20 +691,17 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -613,17 +710,14 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -636,9 +730,7 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.search_migratable_resources(request={}) @@ -646,18 +738,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in results) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in results + ) + def test_search_migratable_resources_pages(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -666,17 +758,14 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -687,19 +776,20 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -708,17 +798,14 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -729,25 +816,27 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in responses) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in responses + ) + @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -756,17 +845,14 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -779,14 +865,15 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): +def test_batch_migrate_resources( + transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -795,10 +882,10 @@ def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_migrate_resources(request) @@ -820,25 +907,27 @@ def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.BatchMigrateResourcesRequest() + @pytest.mark.asyncio -async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): +async def test_batch_migrate_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.BatchMigrateResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -847,11 +936,11 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_migrate_resources(request) @@ -872,20 +961,18 @@ async def test_batch_migrate_resources_async_from_dict(): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_migrate_resources(request) @@ -896,10 +983,7 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -911,13 +995,15 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_migrate_resources(request) @@ -928,29 +1014,30 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -958,23 +1045,33 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -986,19 +1083,25 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -1006,9 +1109,15 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] @pytest.mark.asyncio @@ -1022,8 +1131,14 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -1034,8 +1149,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1054,8 +1168,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1083,13 +1196,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1097,13 +1213,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MigrationServiceGrpcTransport, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) def test_migration_service_base_transport_error(): @@ -1111,13 +1222,15 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1126,9 +1239,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'search_migratable_resources', - 'batch_migrate_resources', - ) + "search_migratable_resources", + "batch_migrate_resources", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1141,23 +1254,28 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1166,11 +1284,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1178,19 +1296,25 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MigrationServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1199,15 +1323,13 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1222,38 +1344,40 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_migration_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1261,12 +1385,11 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1275,12 +1398,22 @@ def test_migration_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1289,7 +1422,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1305,9 +1438,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1321,17 +1452,23 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1348,9 +1485,7 @@ def test_migration_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1363,16 +1498,12 @@ def test_migration_service_transport_channel_mtls_with_adc( def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1380,16 +1511,12 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1400,17 +1527,20 @@ def test_annotated_dataset_path(): dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) - actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) + actual = MigrationServiceClient.annotated_dataset_path( + project, dataset, annotated_dataset + ) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", - + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1418,22 +1548,24 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1441,22 +1573,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "squid" location = "clam" dataset = "whelk" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", - + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", } path = MigrationServiceClient.dataset_path(**expected) @@ -1464,20 +1598,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", - + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1485,22 +1621,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - + "project": "clam", + "location": "whelk", + "model": "octopus", } path = MigrationServiceClient.model_path(**expected) @@ -1508,22 +1646,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", - + "project": "mussel", + "location": "winkle", + "model": "nautilus", } path = MigrationServiceClient.model_path(**expected) @@ -1531,22 +1671,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + expected = "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", - + "project": "clam", + "model": "whelk", + "version": "octopus", } path = MigrationServiceClient.version_path(**expected) @@ -1554,18 +1696,20 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", - + "billing_account": "nudibranch", } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1573,18 +1717,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", - + "folder": "mussel", } path = MigrationServiceClient.common_folder_path(**expected) @@ -1592,18 +1736,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", - + "organization": "nautilus", } path = MigrationServiceClient.common_organization_path(**expected) @@ -1611,18 +1755,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", - + "project": "abalone", } path = MigrationServiceClient.common_project_path(**expected) @@ -1630,20 +1774,22 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", - + "project": "whelk", + "location": "octopus", } path = MigrationServiceClient.common_location_path(**expected) @@ -1655,17 +1801,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index ffe3ecd828..a31f13c873 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.model_service import ( + ModelServiceAsyncClient, +) from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.services.model_service import transports @@ -66,7 +68,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -77,36 +83,45 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [ - ModelServiceClient, - ModelServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - ModelServiceClient, - ModelServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -116,7 +131,7 @@ def test_model_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_model_service_client_get_transport_class(): @@ -130,29 +145,42 @@ def test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -def test_model_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -168,7 +196,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -184,7 +212,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -204,13 +232,15 @@ def test_model_service_client_client_options(client_class, transport_class, tran client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -223,26 +253,50 @@ def test_model_service_client_client_options(client_class, transport_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -265,10 +319,18 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -289,9 +351,14 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -305,16 +372,23 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -327,16 +401,24 @@ def test_model_service_client_client_options_scopes(client_class, transport_clas client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -351,11 +433,11 @@ def test_model_service_client_client_options_credentials_file(client_class, tran def test_model_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -367,10 +449,11 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): +def test_upload_model( + transport: str = "grpc", request_type=model_service.UploadModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -378,11 +461,9 @@ def test_upload_model(transport: str = 'grpc', request_type=model_service.Upload request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.upload_model(request) @@ -404,25 +485,24 @@ def test_upload_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: client.upload_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UploadModelRequest() + @pytest.mark.asyncio -async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UploadModelRequest): +async def test_upload_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -430,12 +510,10 @@ async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.upload_model(request) @@ -456,20 +534,16 @@ async def test_upload_model_async_from_dict(): def test_upload_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.upload_model(request) @@ -480,28 +554,23 @@ def test_upload_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.upload_model(request) @@ -512,29 +581,21 @@ async def test_upload_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_upload_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.upload_model( - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", model=gca_model.Model(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -542,47 +603,40 @@ def test_upload_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") def test_upload_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.upload_model( model_service.UploadModelRequest(), - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", + model=gca_model.Model(name="name_value"), ) @pytest.mark.asyncio async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.upload_model( - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", model=gca_model.Model(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -590,31 +644,28 @@ async def test_upload_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") @pytest.mark.asyncio async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.upload_model( model_service.UploadModelRequest(), - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", + model=gca_model.Model(name="name_value"), ) -def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest): +def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -622,31 +673,21 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - training_pipeline='training_pipeline_value', - - artifact_uri='artifact_uri_value', - - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - - supported_input_storage_formats=['supported_input_storage_formats_value'], - - supported_output_storage_formats=['supported_output_storage_formats_value'], - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", ) response = client.get_model(request) @@ -661,25 +702,31 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_model_from_dict(): @@ -690,25 +737,24 @@ def test_get_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelRequest() + @pytest.mark.asyncio -async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -716,22 +762,28 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.get_model(request) @@ -744,25 +796,31 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -771,19 +829,15 @@ async def test_get_model_async_from_dict(): def test_get_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = model.Model() client.get_model(request) @@ -795,27 +849,20 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -827,99 +874,79 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model( - name='name_value', - ) + client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model( - name='name_value', - ) + response = await client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) -def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): +def test_list_models( + transport: str = "grpc", request_type=model_service.ListModelsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -927,13 +954,10 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_models(request) @@ -948,7 +972,7 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_models_from_dict(): @@ -959,25 +983,24 @@ def test_list_models_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelsRequest() + @pytest.mark.asyncio -async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -985,13 +1008,11 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_models(request) @@ -1004,7 +1025,7 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1013,19 +1034,15 @@ async def test_list_models_async_from_dict(): def test_list_models_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: call.return_value = model_service.ListModelsResponse() client.list_models(request) @@ -1037,28 +1054,23 @@ def test_list_models_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) await client.list_models(request) @@ -1069,138 +1081,98 @@ async def test_list_models_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_models_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_models( - parent='parent_value', - ) + client.list_models(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_models_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_models( - model_service.ListModelsRequest(), - parent='parent_value', + model_service.ListModelsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_models( - parent='parent_value', - ) + response = await client.list_models(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_models( - model_service.ListModelsRequest(), - parent='parent_value', + model_service.ListModelsRequest(), parent="parent_value", ) def test_list_models_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_models(request={}) @@ -1208,147 +1180,96 @@ def test_list_models_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model.Model) - for i in results) + assert all(isinstance(i, model.Model) for i in results) + def test_list_models_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_models_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) - for i in responses) + assert all(isinstance(i, model.Model) for i in responses) + @pytest.mark.asyncio async def test_list_models_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_models(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest): +def test_update_model( + transport: str = "grpc", request_type=model_service.UpdateModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1356,31 +1277,21 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - training_pipeline='training_pipeline_value', - - artifact_uri='artifact_uri_value', - - supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - - supported_input_storage_formats=['supported_input_storage_formats_value'], - - supported_output_storage_formats=['supported_output_storage_formats_value'], - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", ) response = client.update_model(request) @@ -1395,25 +1306,31 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update assert isinstance(response, gca_model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_model_from_dict(): @@ -1424,25 +1341,24 @@ def test_update_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: client.update_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateModelRequest() + @pytest.mark.asyncio -async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateModelRequest): +async def test_update_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1450,22 +1366,28 @@ async def test_update_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.update_model(request) @@ -1478,25 +1400,31 @@ async def test_update_model_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -1505,19 +1433,15 @@ async def test_update_model_async_from_dict(): def test_update_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = 'model.name/value' + request.model.name = "model.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: call.return_value = gca_model.Model() client.update_model(request) @@ -1529,27 +1453,20 @@ def test_update_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model.name=model.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = 'model.name/value' + request.model.name = "model.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) await client.update_model(request) @@ -1561,29 +1478,22 @@ async def test_update_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model.name=model.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] def test_update_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_model( - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1591,36 +1501,30 @@ def test_update_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() @@ -1628,8 +1532,8 @@ async def test_update_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model( - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1637,31 +1541,30 @@ async def test_update_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): +def test_delete_model( + transport: str = "grpc", request_type=model_service.DeleteModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1669,11 +1572,9 @@ def test_delete_model(transport: str = 'grpc', request_type=model_service.Delete request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_model(request) @@ -1695,25 +1596,24 @@ def test_delete_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: client.delete_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.DeleteModelRequest() + @pytest.mark.asyncio -async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteModelRequest): +async def test_delete_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1721,12 +1621,10 @@ async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_model(request) @@ -1747,20 +1645,16 @@ async def test_delete_model_async_from_dict(): def test_delete_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_model(request) @@ -1771,28 +1665,23 @@ def test_delete_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_model(request) @@ -1803,101 +1692,81 @@ async def test_delete_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model( - name='name_value', - ) + client.delete_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_model( - model_service.DeleteModelRequest(), - name='name_value', + model_service.DeleteModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_model( - name='name_value', - ) + response = await client.delete_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model( - model_service.DeleteModelRequest(), - name='name_value', + model_service.DeleteModelRequest(), name="name_value", ) -def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): +def test_export_model( + transport: str = "grpc", request_type=model_service.ExportModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1905,11 +1774,9 @@ def test_export_model(transport: str = 'grpc', request_type=model_service.Export request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_model(request) @@ -1931,25 +1798,24 @@ def test_export_model_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: client.export_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ExportModelRequest() + @pytest.mark.asyncio -async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=model_service.ExportModelRequest): +async def test_export_model_async( + transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1957,12 +1823,10 @@ async def test_export_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_model(request) @@ -1983,20 +1847,16 @@ async def test_export_model_async_from_dict(): def test_export_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_model(request) @@ -2007,28 +1867,23 @@ def test_export_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_model(request) @@ -2039,29 +1894,24 @@ async def test_export_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_model( - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) # Establish that the underlying call was made with the expected @@ -2069,47 +1919,47 @@ def test_export_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) def test_export_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_model( model_service.ExportModelRequest(), - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) -@pytest.mark.asyncio -async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) +@pytest.mark.asyncio +async def test_export_model_flattened_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_model( - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) # Establish that the underlying call was made with the expected @@ -2117,31 +1967,34 @@ async def test_export_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) @pytest.mark.asyncio async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_model( model_service.ExportModelRequest(), - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) -def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): +def test_get_model_evaluation( + transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2150,16 +2003,13 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - - slice_dimensions=['slice_dimensions_value'], - + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], ) response = client.get_model_evaluation(request) @@ -2174,11 +2024,11 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ['slice_dimensions_value'] + assert response.slice_dimensions == ["slice_dimensions_value"] def test_get_model_evaluation_from_dict(): @@ -2189,25 +2039,27 @@ def test_get_model_evaluation_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: client.get_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationRequest() + @pytest.mark.asyncio -async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationRequest): +async def test_get_model_evaluation_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2216,14 +2068,16 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - slice_dimensions=['slice_dimensions_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation( + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], + ) + ) response = await client.get_model_evaluation(request) @@ -2236,11 +2090,11 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', reque # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ['slice_dimensions_value'] + assert response.slice_dimensions == ["slice_dimensions_value"] @pytest.mark.asyncio @@ -2249,19 +2103,17 @@ async def test_get_model_evaluation_async_from_dict(): def test_get_model_evaluation_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: call.return_value = model_evaluation.ModelEvaluation() client.get_model_evaluation(request) @@ -2273,28 +2125,25 @@ def test_get_model_evaluation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + type(client.transport.get_model_evaluation), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) await client.get_model_evaluation(request) @@ -2305,99 +2154,85 @@ async def test_get_model_evaluation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_evaluation_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation( - name='name_value', - ) + client.get_model_evaluation(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), - name='name_value', + model_service.GetModelEvaluationRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_evaluation_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation( - name='name_value', - ) + response = await client.get_model_evaluation(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_evaluation_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), - name='name_value', + model_service.GetModelEvaluationRequest(), name="name_value", ) -def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest): +def test_list_model_evaluations( + transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2406,12 +2241,11 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluations(request) @@ -2426,7 +2260,7 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluations_from_dict(): @@ -2437,25 +2271,27 @@ def test_list_model_evaluations_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: client.list_model_evaluations() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationsRequest() + @pytest.mark.asyncio -async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationsRequest): +async def test_list_model_evaluations_async( + transport: str = "grpc_asyncio", + request_type=model_service.ListModelEvaluationsRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2464,12 +2300,14 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluations(request) @@ -2482,7 +2320,7 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', req # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2491,19 +2329,17 @@ async def test_list_model_evaluations_async_from_dict(): def test_list_model_evaluations_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationsResponse() client.list_model_evaluations(request) @@ -2515,28 +2351,25 @@ def test_list_model_evaluations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + type(client.transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) await client.list_model_evaluations(request) @@ -2547,104 +2380,87 @@ async def test_list_model_evaluations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluations_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluations( - parent='parent_value', - ) + client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluations( - parent='parent_value', - ) + response = await client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) def test_list_model_evaluations_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2653,17 +2469,14 @@ def test_list_model_evaluations_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2676,9 +2489,7 @@ def test_list_model_evaluations_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluations(request={}) @@ -2686,18 +2497,16 @@ def test_list_model_evaluations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in results) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) + def test_list_model_evaluations_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2706,17 +2515,14 @@ def test_list_model_evaluations_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2727,19 +2533,20 @@ def test_list_model_evaluations_pages(): RuntimeError, ) pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2748,17 +2555,14 @@ async def test_list_model_evaluations_async_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2769,25 +2573,25 @@ async def test_list_model_evaluations_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in responses) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) + @pytest.mark.asyncio async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2796,17 +2600,14 @@ async def test_list_model_evaluations_async_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2819,14 +2620,15 @@ async def test_list_model_evaluations_async_pages(): pages = [] async for page_ in (await client.list_model_evaluations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): +def test_get_model_evaluation_slice( + transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2835,14 +2637,11 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - + name="name_value", metrics_schema_uri="metrics_schema_uri_value", ) response = client.get_model_evaluation_slice(request) @@ -2857,9 +2656,9 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" def test_get_model_evaluation_slice_from_dict(): @@ -2870,25 +2669,27 @@ def test_get_model_evaluation_slice_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: client.get_model_evaluation_slice() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelEvaluationSliceRequest() + @pytest.mark.asyncio -async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationSliceRequest): +async def test_get_model_evaluation_slice_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationSliceRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2897,13 +2698,14 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + ) response = await client.get_model_evaluation_slice(request) @@ -2916,9 +2718,9 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" @pytest.mark.asyncio @@ -2927,19 +2729,17 @@ async def test_get_model_evaluation_slice_async_from_dict(): def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: call.return_value = model_evaluation_slice.ModelEvaluationSlice() client.get_model_evaluation_slice(request) @@ -2951,28 +2751,25 @@ def test_get_model_evaluation_slice_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) await client.get_model_evaluation_slice(request) @@ -2983,99 +2780,85 @@ async def test_get_model_evaluation_slice_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation_slice( - name='name_value', - ) + client.get_model_evaluation_slice(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), - name='name_value', + model_service.GetModelEvaluationSliceRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation_slice( - name='name_value', - ) + response = await client.get_model_evaluation_slice(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), - name='name_value', + model_service.GetModelEvaluationSliceRequest(), name="name_value", ) -def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest): +def test_list_model_evaluation_slices( + transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3084,12 +2867,11 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluation_slices(request) @@ -3104,7 +2886,7 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluation_slices_from_dict(): @@ -3115,25 +2897,27 @@ def test_list_model_evaluation_slices_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: client.list_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelEvaluationSlicesRequest() + @pytest.mark.asyncio -async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationSlicesRequest): +async def test_list_model_evaluation_slices_async( + transport: str = "grpc_asyncio", + request_type=model_service.ListModelEvaluationSlicesRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3142,12 +2926,14 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluation_slices(request) @@ -3160,7 +2946,7 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3169,19 +2955,17 @@ async def test_list_model_evaluation_slices_async_from_dict(): def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationSlicesResponse() client.list_model_evaluation_slices(request) @@ -3193,28 +2977,25 @@ def test_list_model_evaluation_slices_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluation_slices_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) await client.list_model_evaluation_slices(request) @@ -3225,104 +3006,87 @@ async def test_list_model_evaluation_slices_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluation_slices( - parent='parent_value', - ) + client.list_model_evaluation_slices(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), - parent='parent_value', + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluation_slices( - parent='parent_value', - ) + response = await client.list_model_evaluation_slices(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), - parent='parent_value', + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", ) def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3331,17 +3095,16 @@ def test_list_model_evaluation_slices_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3354,9 +3117,7 @@ def test_list_model_evaluation_slices_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluation_slices(request={}) @@ -3364,18 +3125,18 @@ def test_list_model_evaluation_slices_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in results) + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results + ) + def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3384,17 +3145,16 @@ def test_list_model_evaluation_slices_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3405,19 +3165,20 @@ def test_list_model_evaluation_slices_pages(): RuntimeError, ) pages = list(client.list_model_evaluation_slices(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluation_slices), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3426,17 +3187,16 @@ async def test_list_model_evaluation_slices_async_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3447,25 +3207,28 @@ async def test_list_model_evaluation_slices_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluation_slices(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in responses) + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in responses + ) + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluation_slices), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3474,17 +3237,16 @@ async def test_list_model_evaluation_slices_async_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3495,9 +3257,11 @@ async def test_list_model_evaluation_slices_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_model_evaluation_slices(request={})).pages: + async for page_ in ( + await client.list_model_evaluation_slices(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3508,8 +3272,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3528,8 +3291,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3557,13 +3319,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3571,13 +3336,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.ModelServiceGrpcTransport, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) def test_model_service_base_transport_error(): @@ -3585,13 +3345,15 @@ def test_model_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_model_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3600,17 +3362,17 @@ def test_model_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'upload_model', - 'get_model', - 'list_models', - 'update_model', - 'delete_model', - 'export_model', - 'get_model_evaluation', - 'list_model_evaluations', - 'get_model_evaluation_slice', - 'list_model_evaluation_slices', - ) + "upload_model", + "get_model", + "list_models", + "update_model", + "delete_model", + "export_model", + "get_model_evaluation", + "list_model_evaluations", + "get_model_evaluation_slice", + "list_model_evaluation_slices", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3623,23 +3385,28 @@ def test_model_service_base_transport(): def test_model_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_model_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport() @@ -3648,11 +3415,11 @@ def test_model_service_base_transport_with_adc(): def test_model_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) ModelServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3660,19 +3427,22 @@ def test_model_service_auth_adc(): def test_model_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -3681,15 +3451,13 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3704,38 +3472,40 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_model_service_host_no_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_model_service_host_with_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_model_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3743,12 +3513,11 @@ def test_model_service_grpc_transport_channel(): def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3757,12 +3526,17 @@ def test_model_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3771,7 +3545,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3787,9 +3561,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3803,17 +3575,20 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3830,9 +3605,7 @@ def test_model_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3845,16 +3618,12 @@ def test_model_service_transport_channel_mtls_with_adc( def test_model_service_grpc_lro_client(): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3862,16 +3631,12 @@ def test_model_service_grpc_lro_client(): def test_model_service_grpc_lro_async_client(): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3882,17 +3647,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = ModelServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = ModelServiceClient.endpoint_path(**expected) @@ -3900,22 +3666,24 @@ def test_parse_endpoint_path(): actual = ModelServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = ModelServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = ModelServiceClient.model_path(**expected) @@ -3923,24 +3691,28 @@ def test_parse_model_path(): actual = ModelServiceClient.parse_model_path(path) assert expected == actual + def test_model_evaluation_path(): project = "squid" location = "clam" model = "whelk" evaluation = "octopus" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) - actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) + actual = ModelServiceClient.model_evaluation_path( + project, location, model, evaluation + ) assert expected == actual def test_parse_model_evaluation_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "model": "cuttlefish", - "evaluation": "mussel", - + "project": "oyster", + "location": "nudibranch", + "model": "cuttlefish", + "evaluation": "mussel", } path = ModelServiceClient.model_evaluation_path(**expected) @@ -3948,6 +3720,7 @@ def test_parse_model_evaluation_path(): actual = ModelServiceClient.parse_model_evaluation_path(path) assert expected == actual + def test_model_evaluation_slice_path(): project = "winkle" location = "nautilus" @@ -3955,19 +3728,26 @@ def test_model_evaluation_slice_path(): evaluation = "abalone" slice = "squid" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) - actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) + actual = ModelServiceClient.model_evaluation_slice_path( + project, location, model, evaluation, slice + ) assert expected == actual def test_parse_model_evaluation_slice_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - "evaluation": "oyster", - "slice": "nudibranch", - + "project": "clam", + "location": "whelk", + "model": "octopus", + "evaluation": "oyster", + "slice": "nudibranch", } path = ModelServiceClient.model_evaluation_slice_path(**expected) @@ -3975,22 +3755,26 @@ def test_parse_model_evaluation_slice_path(): actual = ModelServiceClient.parse_model_evaluation_slice_path(path) assert expected == actual + def test_training_pipeline_path(): project = "cuttlefish" location = "mussel" training_pipeline = "winkle" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) - actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = ModelServiceClient.training_pipeline_path( + project, location, training_pipeline + ) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "nautilus", - "location": "scallop", - "training_pipeline": "abalone", - + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", } path = ModelServiceClient.training_pipeline_path(**expected) @@ -3998,18 +3782,20 @@ def test_parse_training_pipeline_path(): actual = ModelServiceClient.parse_training_pipeline_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = ModelServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = ModelServiceClient.common_billing_account_path(**expected) @@ -4017,18 +3803,18 @@ def test_parse_common_billing_account_path(): actual = ModelServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = ModelServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = ModelServiceClient.common_folder_path(**expected) @@ -4036,18 +3822,18 @@ def test_parse_common_folder_path(): actual = ModelServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = ModelServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = ModelServiceClient.common_organization_path(**expected) @@ -4055,18 +3841,18 @@ def test_parse_common_organization_path(): actual = ModelServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = ModelServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = ModelServiceClient.common_project_path(**expected) @@ -4074,20 +3860,22 @@ def test_parse_common_project_path(): actual = ModelServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = ModelServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = ModelServiceClient.common_location_path(**expected) @@ -4099,17 +3887,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = ModelServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index be11879c35..e353077d80 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceClient +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + PipelineServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + PipelineServiceClient, +) from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports from google.cloud.aiplatform_v1beta1.types import deployed_model_ref @@ -50,7 +54,9 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 from google.oauth2 import service_account from google.protobuf import any_pb2 as gp_any # type: ignore @@ -68,7 +74,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -79,36 +89,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PipelineServiceClient._get_default_mtls_endpoint(None) is None - assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - PipelineServiceClient, - PipelineServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] +) def test_pipeline_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - PipelineServiceClient, - PipelineServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] +) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -118,7 +144,7 @@ def test_pipeline_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_pipeline_service_client_get_transport_class(): @@ -132,29 +158,44 @@ def test_pipeline_service_client_get_transport_class(): assert transport == transports.PipelineServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) -@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) -def test_pipeline_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +def test_pipeline_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -170,7 +211,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -186,7 +227,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -206,13 +247,15 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -225,26 +268,62 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) -@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "true", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "false", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_pipeline_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -267,10 +346,18 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -291,9 +378,14 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -307,16 +399,23 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -329,16 +428,24 @@ def test_pipeline_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -353,10 +460,12 @@ def test_pipeline_service_client_client_options_credentials_file(client_class, t def test_pipeline_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = PipelineServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -369,10 +478,11 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): +def test_create_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -381,18 +491,14 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline( - name='name_value', - - display_name='display_name_value', - - training_task_definition='training_task_definition_value', - + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) response = client.create_training_pipeline(request) @@ -407,11 +513,11 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -424,25 +530,27 @@ def test_create_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: client.create_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CreateTrainingPipelineRequest): +async def test_create_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CreateTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -451,15 +559,17 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( - name='name_value', - display_name='display_name_value', - training_task_definition='training_task_definition_value', - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) response = await client.create_training_pipeline(request) @@ -472,11 +582,11 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -487,19 +597,17 @@ async def test_create_training_pipeline_async_from_dict(): def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: call.return_value = gca_training_pipeline.TrainingPipeline() client.create_training_pipeline(request) @@ -511,28 +619,25 @@ def test_create_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) await client.create_training_pipeline(request) @@ -543,29 +648,24 @@ async def test_create_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_training_pipeline( - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -573,45 +673,45 @@ def test_create_training_pipeline_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_training_pipeline( - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -619,31 +719,32 @@ async def test_create_training_pipeline_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) -def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest): +def test_get_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -652,18 +753,14 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline( - name='name_value', - - display_name='display_name_value', - - training_task_definition='training_task_definition_value', - + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) response = client.get_training_pipeline(request) @@ -678,11 +775,11 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -695,25 +792,27 @@ def test_get_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: client.get_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.GetTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.GetTrainingPipelineRequest): +async def test_get_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.GetTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -722,15 +821,17 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( - name='name_value', - display_name='display_name_value', - training_task_definition='training_task_definition_value', - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) response = await client.get_training_pipeline(request) @@ -743,11 +844,11 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -758,19 +859,17 @@ async def test_get_training_pipeline_async_from_dict(): def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: call.return_value = training_pipeline.TrainingPipeline() client.get_training_pipeline(request) @@ -782,28 +881,25 @@ def test_get_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) + type(client.transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) await client.get_training_pipeline(request) @@ -814,99 +910,85 @@ async def test_get_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_training_pipeline( - name='name_value', - ) + client.get_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), - name='name_value', + pipeline_service.GetTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_training_pipeline( - name='name_value', - ) + response = await client.get_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), - name='name_value', + pipeline_service.GetTrainingPipelineRequest(), name="name_value", ) -def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): +def test_list_training_pipelines( + transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -915,12 +997,11 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_training_pipelines(request) @@ -935,7 +1016,7 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_ assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_training_pipelines_from_dict(): @@ -946,25 +1027,27 @@ def test_list_training_pipelines_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: client.list_training_pipelines() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + @pytest.mark.asyncio -async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.ListTrainingPipelinesRequest): +async def test_list_training_pipelines_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.ListTrainingPipelinesRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -973,12 +1056,14 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_training_pipelines(request) @@ -991,7 +1076,7 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', re # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1000,19 +1085,17 @@ async def test_list_training_pipelines_async_from_dict(): def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: call.return_value = pipeline_service.ListTrainingPipelinesResponse() client.list_training_pipelines(request) @@ -1024,28 +1107,25 @@ def test_list_training_pipelines_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) await client.list_training_pipelines(request) @@ -1056,104 +1136,87 @@ async def test_list_training_pipelines_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_training_pipelines_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_training_pipelines( - parent='parent_value', - ) + client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_training_pipelines_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_training_pipelines_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_training_pipelines( - parent='parent_value', - ) + response = await client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) def test_list_training_pipelines_pager(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1162,17 +1225,14 @@ def test_list_training_pipelines_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1185,9 +1245,7 @@ def test_list_training_pipelines_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_training_pipelines(request={}) @@ -1195,18 +1253,16 @@ def test_list_training_pipelines_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in results) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) + def test_list_training_pipelines_pages(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1215,17 +1271,14 @@ def test_list_training_pipelines_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1236,19 +1289,20 @@ def test_list_training_pipelines_pages(): RuntimeError, ) pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1257,17 +1311,14 @@ async def test_list_training_pipelines_async_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1278,25 +1329,25 @@ async def test_list_training_pipelines_async_pager(): RuntimeError, ) async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in responses) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) + @pytest.mark.asyncio async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1305,17 +1356,14 @@ async def test_list_training_pipelines_async_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1328,14 +1376,15 @@ async def test_list_training_pipelines_async_pages(): pages = [] async for page_ in (await client.list_training_pipelines(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest): +def test_delete_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1344,10 +1393,10 @@ def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_training_pipeline(request) @@ -1369,25 +1418,27 @@ def test_delete_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: client.delete_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.DeleteTrainingPipelineRequest): +async def test_delete_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.DeleteTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1396,11 +1447,11 @@ async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_training_pipeline(request) @@ -1421,20 +1472,18 @@ async def test_delete_training_pipeline_async_from_dict(): def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_training_pipeline(request) @@ -1445,28 +1494,25 @@ def test_delete_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_training_pipeline(request) @@ -1477,101 +1523,85 @@ async def test_delete_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_training_pipeline( - name='name_value', - ) + client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_training_pipeline( - name='name_value', - ) + response = await client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) -def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest): +def test_cancel_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1580,8 +1610,8 @@ def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1605,25 +1635,27 @@ def test_cancel_training_pipeline_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: client.cancel_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + @pytest.mark.asyncio -async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CancelTrainingPipelineRequest): +async def test_cancel_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CancelTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1632,8 +1664,8 @@ async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1655,19 +1687,17 @@ async def test_cancel_training_pipeline_async_from_dict(): def test_cancel_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = None client.cancel_training_pipeline(request) @@ -1679,27 +1709,22 @@ def test_cancel_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_training_pipeline(request) @@ -1711,92 +1736,75 @@ async def test_cancel_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_training_pipeline( - name='name_value', - ) + client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_training_pipeline( - name='name_value', - ) + response = await client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @@ -1807,8 +1815,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1827,8 +1834,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1856,13 +1862,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1870,13 +1879,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.PipelineServiceGrpcTransport, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) def test_pipeline_service_base_transport_error(): @@ -1884,13 +1888,15 @@ def test_pipeline_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_pipeline_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1899,12 +1905,12 @@ def test_pipeline_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_training_pipeline', - 'get_training_pipeline', - 'list_training_pipelines', - 'delete_training_pipeline', - 'cancel_training_pipeline', - ) + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1917,23 +1923,28 @@ def test_pipeline_service_base_transport(): def test_pipeline_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_pipeline_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport() @@ -1942,11 +1953,11 @@ def test_pipeline_service_base_transport_with_adc(): def test_pipeline_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PipelineServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1954,19 +1965,25 @@ def test_pipeline_service_auth_adc(): def test_pipeline_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) -def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1975,15 +1992,13 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1998,38 +2013,40 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_pipeline_service_host_no_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_pipeline_service_host_with_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_pipeline_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2037,12 +2054,11 @@ def test_pipeline_service_grpc_transport_channel(): def test_pipeline_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PipelineServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2051,12 +2067,22 @@ def test_pipeline_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_pipeline_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2065,7 +2091,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2081,9 +2107,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2097,17 +2121,23 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) -def test_pipeline_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2124,9 +2154,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2139,16 +2167,12 @@ def test_pipeline_service_transport_channel_mtls_with_adc( def test_pipeline_service_grpc_lro_client(): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2156,16 +2180,12 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2176,17 +2196,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = PipelineServiceClient.endpoint_path(**expected) @@ -2194,22 +2215,24 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = PipelineServiceClient.model_path(**expected) @@ -2217,22 +2240,26 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual + def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) - actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = PipelineServiceClient.training_pipeline_path( + project, location, training_pipeline + ) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", - + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2240,18 +2267,20 @@ def test_parse_training_pipeline_path(): actual = PipelineServiceClient.parse_training_pipeline_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = PipelineServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2259,18 +2288,18 @@ def test_parse_common_billing_account_path(): actual = PipelineServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = PipelineServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = PipelineServiceClient.common_folder_path(**expected) @@ -2278,18 +2307,18 @@ def test_parse_common_folder_path(): actual = PipelineServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = PipelineServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = PipelineServiceClient.common_organization_path(**expected) @@ -2297,18 +2326,18 @@ def test_parse_common_organization_path(): actual = PipelineServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = PipelineServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = PipelineServiceClient.common_project_path(**expected) @@ -2316,20 +2345,22 @@ def test_parse_common_project_path(): actual = PipelineServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = PipelineServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = PipelineServiceClient.common_location_path(**expected) @@ -2341,17 +2372,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = PipelineServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index 06ec395aaf..3daed56994 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceClient +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -56,7 +60,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -67,36 +75,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - SpecialistPoolServiceClient, - SpecialistPoolServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] +) def test_specialist_pool_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - SpecialistPoolServiceClient, - SpecialistPoolServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] +) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -106,7 +131,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_client_get_transport_class(): @@ -120,29 +145,48 @@ def test_specialist_pool_service_client_get_transport_class(): assert transport == transports.SpecialistPoolServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) -def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -158,7 +202,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -174,7 +218,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -194,13 +238,15 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -213,26 +259,62 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "true", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "false", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_specialist_pool_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -255,10 +337,18 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -279,9 +369,14 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -295,16 +390,27 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -317,16 +423,28 @@ def test_specialist_pool_service_client_client_options_scopes(client_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -341,10 +459,12 @@ def test_specialist_pool_service_client_client_options_credentials_file(client_c def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -357,10 +477,12 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +def test_create_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -369,10 +491,10 @@ def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_specialist_pool(request) @@ -394,25 +516,27 @@ def test_create_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: client.create_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +async def test_create_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -421,11 +545,11 @@ async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_specialist_pool(request) @@ -453,13 +577,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_specialist_pool(request) @@ -470,10 +594,7 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -485,13 +606,15 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_specialist_pool(request) @@ -502,10 +625,7 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_specialist_pool_flattened(): @@ -515,16 +635,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -532,9 +652,11 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) def test_create_specialist_pool_flattened_error(): @@ -547,8 +669,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) @@ -560,19 +682,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -580,9 +702,11 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) @pytest.mark.asyncio @@ -596,15 +720,17 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) -def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): +def test_get_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -613,20 +739,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", specialist_managers_count=2662, - - specialist_manager_emails=['specialist_manager_emails_value'], - - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], ) response = client.get_specialist_pool(request) @@ -641,15 +762,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] def test_get_specialist_pool_from_dict(): @@ -660,25 +781,27 @@ def test_get_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: client.get_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.GetSpecialistPoolRequest): +async def test_get_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -687,16 +810,18 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( - name='name_value', - display_name='display_name_value', - specialist_managers_count=2662, - specialist_manager_emails=['specialist_manager_emails_value'], - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + ) response = await client.get_specialist_pool(request) @@ -709,15 +834,15 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] @pytest.mark.asyncio @@ -733,12 +858,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -750,10 +875,7 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -765,13 +887,15 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) await client.get_specialist_pool(request) @@ -782,10 +906,7 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_specialist_pool_flattened(): @@ -795,23 +916,21 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_specialist_pool( - name='name_value', - ) + client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_specialist_pool_flattened_error(): @@ -823,8 +942,7 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) @@ -836,24 +954,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool( - name='name_value', - ) + response = await client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -866,15 +984,16 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) -def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +def test_list_specialist_pools( + transport: str = "grpc", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -883,12 +1002,11 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_specialist_pools(request) @@ -903,7 +1021,7 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_specialist_pools_from_dict(): @@ -914,25 +1032,27 @@ def test_list_specialist_pools_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: client.list_specialist_pools() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + @pytest.mark.asyncio -async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +async def test_list_specialist_pools_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -941,12 +1061,14 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_specialist_pools(request) @@ -959,7 +1081,7 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -975,12 +1097,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -992,10 +1114,7 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1007,13 +1126,15 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + type(client.transport.list_specialist_pools), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) await client.list_specialist_pools(request) @@ -1024,10 +1145,7 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_specialist_pools_flattened(): @@ -1037,23 +1155,21 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_specialist_pools( - parent='parent_value', - ) + client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_specialist_pools_flattened_error(): @@ -1065,8 +1181,7 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) @@ -1078,24 +1193,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools( - parent='parent_value', - ) + response = await client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -1108,20 +1223,17 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1130,17 +1242,14 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1153,9 +1262,7 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_specialist_pools(request={}) @@ -1163,18 +1270,16 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in results) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) + def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1183,17 +1288,14 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1204,9 +1306,10 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1215,8 +1318,10 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1225,17 +1330,14 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1246,14 +1348,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in responses) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) + @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1263,8 +1365,10 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1273,17 +1377,14 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1296,14 +1397,16 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +def test_delete_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1312,10 +1415,10 @@ def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_specialist_pool(request) @@ -1337,25 +1440,27 @@ def test_delete_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: client.delete_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +async def test_delete_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1364,11 +1469,11 @@ async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_specialist_pool(request) @@ -1396,13 +1501,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_specialist_pool(request) @@ -1413,10 +1518,7 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1428,13 +1530,15 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_specialist_pool(request) @@ -1445,10 +1549,7 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_specialist_pool_flattened(): @@ -1458,23 +1559,21 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool( - name='name_value', - ) + client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_specialist_pool_flattened_error(): @@ -1486,8 +1585,7 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) @@ -1499,26 +1597,24 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool( - name='name_value', - ) + response = await client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -1531,15 +1627,16 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) -def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +def test_update_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1548,10 +1645,10 @@ def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_specialist_pool(request) @@ -1573,25 +1670,27 @@ def test_update_specialist_pool_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: client.update_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + @pytest.mark.asyncio -async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +async def test_update_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1600,11 +1699,11 @@ async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_specialist_pool(request) @@ -1632,13 +1731,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_specialist_pool(request) @@ -1650,9 +1749,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1664,13 +1763,15 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_specialist_pool(request) @@ -1682,9 +1783,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] def test_update_specialist_pool_flattened(): @@ -1694,16 +1795,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1711,9 +1812,11 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_specialist_pool_flattened_error(): @@ -1726,8 +1829,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1739,19 +1842,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1759,9 +1862,11 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -1775,8 +1880,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1787,8 +1892,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1807,8 +1911,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1836,13 +1939,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1853,10 +1959,7 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance( - client.transport, - transports.SpecialistPoolServiceGrpcTransport, - ) + assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) def test_specialist_pool_service_base_transport_error(): @@ -1864,13 +1967,15 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1879,12 +1984,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_specialist_pool', - 'get_specialist_pool', - 'list_specialist_pools', - 'delete_specialist_pool', - 'update_specialist_pool', - ) + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1897,23 +2002,28 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -1922,11 +2032,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1934,18 +2044,26 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( - transport_class + transport_class, ): cred = credentials.AnonymousCredentials() @@ -1955,15 +2073,13 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1978,38 +2094,40 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2017,12 +2135,11 @@ def test_specialist_pool_service_grpc_transport_channel(): def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2031,12 +2148,22 @@ def test_specialist_pool_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2045,7 +2172,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2061,9 +2188,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2077,17 +2202,23 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) -def test_specialist_pool_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2104,9 +2235,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2119,16 +2248,12 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2136,16 +2261,12 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2156,17 +2277,20 @@ def test_specialist_pool_path(): location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) - actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + actual = SpecialistPoolServiceClient.specialist_pool_path( + project, location, specialist_pool + ) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", - + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -2174,18 +2298,20 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -2193,18 +2319,18 @@ def test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2212,18 +2338,18 @@ def test_parse_common_folder_path(): actual = SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2231,18 +2357,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = SpecialistPoolServiceClient.common_project_path(**expected) @@ -2250,20 +2376,22 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2275,17 +2403,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py index 3370e5011e..770c95794f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.vizier_service import ( + VizierServiceAsyncClient, +) from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceClient from google.cloud.aiplatform_v1beta1.services.vizier_service import pagers from google.cloud.aiplatform_v1beta1.services.vizier_service import transports @@ -57,7 +59,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -68,36 +74,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert VizierServiceClient._get_default_mtls_endpoint(None) is None - assert VizierServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert VizierServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert VizierServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert VizierServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert VizierServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + VizierServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - VizierServiceClient, - VizierServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [VizierServiceClient, VizierServiceAsyncClient,] +) def test_vizier_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - VizierServiceClient, - VizierServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [VizierServiceClient, VizierServiceAsyncClient,] +) def test_vizier_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -107,7 +129,7 @@ def test_vizier_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_vizier_service_client_get_transport_class(): @@ -121,29 +143,44 @@ def test_vizier_service_client_get_transport_class(): assert transport == transports.VizierServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), - (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(VizierServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceClient)) -@mock.patch.object(VizierServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceAsyncClient)) -def test_vizier_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + VizierServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceClient), +) +@mock.patch.object( + VizierServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceAsyncClient), +) +def test_vizier_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(VizierServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(VizierServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(VizierServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(VizierServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -159,7 +196,7 @@ def test_vizier_service_client_client_options(client_class, transport_class, tra # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -175,7 +212,7 @@ def test_vizier_service_client_client_options(client_class, transport_class, tra # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -195,13 +232,15 @@ def test_vizier_service_client_client_options(client_class, transport_class, tra client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -214,26 +253,52 @@ def test_vizier_service_client_client_options(client_class, transport_class, tra client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "true"), - (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "false"), - (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(VizierServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceClient)) -@mock.patch.object(VizierServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(VizierServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "true"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "false"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + VizierServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceClient), +) +@mock.patch.object( + VizierServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_vizier_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_vizier_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -256,10 +321,18 @@ def test_vizier_service_client_mtls_env_auto(client_class, transport_class, tran # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -280,9 +353,14 @@ def test_vizier_service_client_mtls_env_auto(client_class, transport_class, tran ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -296,16 +374,23 @@ def test_vizier_service_client_mtls_env_auto(client_class, transport_class, tran ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), - (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_vizier_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_vizier_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -318,16 +403,24 @@ def test_vizier_service_client_client_options_scopes(client_class, transport_cla client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), - (VizierServiceAsyncClient, transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_vizier_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_vizier_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -342,10 +435,12 @@ def test_vizier_service_client_client_options_credentials_file(client_class, tra def test_vizier_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = VizierServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -358,10 +453,11 @@ def test_vizier_service_client_client_options_from_dict(): ) -def test_create_study(transport: str = 'grpc', request_type=vizier_service.CreateStudyRequest): +def test_create_study( + transport: str = "grpc", request_type=vizier_service.CreateStudyRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -369,19 +465,13 @@ def test_create_study(transport: str = 'grpc', request_type=vizier_service.Creat request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_study.Study( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=gca_study.Study.State.ACTIVE, - - inactive_reason='inactive_reason_value', - + inactive_reason="inactive_reason_value", ) response = client.create_study(request) @@ -396,13 +486,13 @@ def test_create_study(transport: str = 'grpc', request_type=vizier_service.Creat assert isinstance(response, gca_study.Study) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_study.Study.State.ACTIVE - assert response.inactive_reason == 'inactive_reason_value' + assert response.inactive_reason == "inactive_reason_value" def test_create_study_from_dict(): @@ -413,25 +503,24 @@ def test_create_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: client.create_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CreateStudyRequest() + @pytest.mark.asyncio -async def test_create_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CreateStudyRequest): +async def test_create_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CreateStudyRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -439,16 +528,16 @@ async def test_create_study_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_study.Study( - name='name_value', - display_name='display_name_value', - state=gca_study.Study.State.ACTIVE, - inactive_reason='inactive_reason_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_study.Study( + name="name_value", + display_name="display_name_value", + state=gca_study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + ) response = await client.create_study(request) @@ -461,13 +550,13 @@ async def test_create_study_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, gca_study.Study) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_study.Study.State.ACTIVE - assert response.inactive_reason == 'inactive_reason_value' + assert response.inactive_reason == "inactive_reason_value" @pytest.mark.asyncio @@ -476,19 +565,15 @@ async def test_create_study_async_from_dict(): def test_create_study_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateStudyRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: call.return_value = gca_study.Study() client.create_study(request) @@ -500,27 +585,20 @@ def test_create_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_study_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateStudyRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_study.Study()) await client.create_study(request) @@ -532,29 +610,21 @@ async def test_create_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_study_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_study.Study() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_study( - parent='parent_value', - study=gca_study.Study(name='name_value'), + parent="parent_value", study=gca_study.Study(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -562,36 +632,30 @@ def test_create_study_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].study == gca_study.Study(name='name_value') + assert args[0].study == gca_study.Study(name="name_value") def test_create_study_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_study( vizier_service.CreateStudyRequest(), - parent='parent_value', - study=gca_study.Study(name='name_value'), + parent="parent_value", + study=gca_study.Study(name="name_value"), ) @pytest.mark.asyncio async def test_create_study_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_study), - '__call__') as call: + with mock.patch.object(type(client.transport.create_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_study.Study() @@ -599,8 +663,7 @@ async def test_create_study_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_study( - parent='parent_value', - study=gca_study.Study(name='name_value'), + parent="parent_value", study=gca_study.Study(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -608,31 +671,30 @@ async def test_create_study_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].study == gca_study.Study(name='name_value') + assert args[0].study == gca_study.Study(name="name_value") @pytest.mark.asyncio async def test_create_study_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_study( vizier_service.CreateStudyRequest(), - parent='parent_value', - study=gca_study.Study(name='name_value'), + parent="parent_value", + study=gca_study.Study(name="name_value"), ) -def test_get_study(transport: str = 'grpc', request_type=vizier_service.GetStudyRequest): +def test_get_study( + transport: str = "grpc", request_type=vizier_service.GetStudyRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -640,19 +702,13 @@ def test_get_study(transport: str = 'grpc', request_type=vizier_service.GetStudy request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Study( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=study.Study.State.ACTIVE, - - inactive_reason='inactive_reason_value', - + inactive_reason="inactive_reason_value", ) response = client.get_study(request) @@ -667,13 +723,13 @@ def test_get_study(transport: str = 'grpc', request_type=vizier_service.GetStudy assert isinstance(response, study.Study) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == 'inactive_reason_value' + assert response.inactive_reason == "inactive_reason_value" def test_get_study_from_dict(): @@ -684,25 +740,24 @@ def test_get_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: client.get_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.GetStudyRequest() + @pytest.mark.asyncio -async def test_get_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.GetStudyRequest): +async def test_get_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.GetStudyRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -710,16 +765,16 @@ async def test_get_study_async(transport: str = 'grpc_asyncio', request_type=viz request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study( - name='name_value', - display_name='display_name_value', - state=study.Study.State.ACTIVE, - inactive_reason='inactive_reason_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Study( + name="name_value", + display_name="display_name_value", + state=study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + ) response = await client.get_study(request) @@ -732,13 +787,13 @@ async def test_get_study_async(transport: str = 'grpc_asyncio', request_type=viz # Establish that the response is the type that we expect. assert isinstance(response, study.Study) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == 'inactive_reason_value' + assert response.inactive_reason == "inactive_reason_value" @pytest.mark.asyncio @@ -747,19 +802,15 @@ async def test_get_study_async_from_dict(): def test_get_study_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetStudyRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: call.return_value = study.Study() client.get_study(request) @@ -771,27 +822,20 @@ def test_get_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_study_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetStudyRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) await client.get_study(request) @@ -803,99 +847,79 @@ async def test_get_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_study_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Study() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_study( - name='name_value', - ) + client.get_study(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_study_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_study( - vizier_service.GetStudyRequest(), - name='name_value', + vizier_service.GetStudyRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_study_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_study), - '__call__') as call: + with mock.patch.object(type(client.transport.get_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Study() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_study( - name='name_value', - ) + response = await client.get_study(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_study_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_study( - vizier_service.GetStudyRequest(), - name='name_value', + vizier_service.GetStudyRequest(), name="name_value", ) -def test_list_studies(transport: str = 'grpc', request_type=vizier_service.ListStudiesRequest): +def test_list_studies( + transport: str = "grpc", request_type=vizier_service.ListStudiesRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -903,13 +927,10 @@ def test_list_studies(transport: str = 'grpc', request_type=vizier_service.ListS request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListStudiesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_studies(request) @@ -924,7 +945,7 @@ def test_list_studies(transport: str = 'grpc', request_type=vizier_service.ListS assert isinstance(response, pagers.ListStudiesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_studies_from_dict(): @@ -935,25 +956,24 @@ def test_list_studies_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: client.list_studies() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.ListStudiesRequest() + @pytest.mark.asyncio -async def test_list_studies_async(transport: str = 'grpc_asyncio', request_type=vizier_service.ListStudiesRequest): +async def test_list_studies_async( + transport: str = "grpc_asyncio", request_type=vizier_service.ListStudiesRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -961,13 +981,11 @@ async def test_list_studies_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListStudiesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListStudiesResponse(next_page_token="next_page_token_value",) + ) response = await client.list_studies(request) @@ -980,7 +998,7 @@ async def test_list_studies_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListStudiesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -989,19 +1007,15 @@ async def test_list_studies_async_from_dict(): def test_list_studies_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListStudiesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: call.return_value = vizier_service.ListStudiesResponse() client.list_studies(request) @@ -1013,28 +1027,23 @@ def test_list_studies_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_studies_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListStudiesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListStudiesResponse()) + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListStudiesResponse() + ) await client.list_studies(request) @@ -1045,138 +1054,100 @@ async def test_list_studies_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_studies_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListStudiesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_studies( - parent='parent_value', - ) + client.list_studies(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_studies_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_studies( - vizier_service.ListStudiesRequest(), - parent='parent_value', + vizier_service.ListStudiesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_studies_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListStudiesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListStudiesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListStudiesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_studies( - parent='parent_value', - ) + response = await client.list_studies(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_studies_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_studies( - vizier_service.ListStudiesRequest(), - parent='parent_value', + vizier_service.ListStudiesRequest(), parent="parent_value", ) def test_list_studies_pager(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - study.Study(), - ], - next_page_token='abc', + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[], - next_page_token='def', - ), - vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - ], - next_page_token='ghi', + studies=[study.Study(),], next_page_token="ghi", ), vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - ], + studies=[study.Study(), study.Study(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_studies(request={}) @@ -1184,147 +1155,102 @@ def test_list_studies_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, study.Study) - for i in results) + assert all(isinstance(i, study.Study) for i in results) + def test_list_studies_pages(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_studies), - '__call__') as call: + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - study.Study(), - ], - next_page_token='abc', - ), - vizier_service.ListStudiesResponse( - studies=[], - next_page_token='def', + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - ], - next_page_token='ghi', + studies=[study.Study(),], next_page_token="ghi", ), vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - ], + studies=[study.Study(), study.Study(),], ), RuntimeError, ) pages = list(client.list_studies(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_studies_async_pager(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_studies), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_studies), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - study.Study(), - ], - next_page_token='abc', - ), - vizier_service.ListStudiesResponse( - studies=[], - next_page_token='def', + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - ], - next_page_token='ghi', + studies=[study.Study(),], next_page_token="ghi", ), vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - ], + studies=[study.Study(), study.Study(),], ), RuntimeError, ) async_pager = await client.list_studies(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, study.Study) - for i in responses) + assert all(isinstance(i, study.Study) for i in responses) + @pytest.mark.asyncio async def test_list_studies_async_pages(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_studies), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_studies), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - study.Study(), - ], - next_page_token='abc', + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), vizier_service.ListStudiesResponse( - studies=[], - next_page_token='def', + studies=[study.Study(),], next_page_token="ghi", ), vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - ], - next_page_token='ghi', - ), - vizier_service.ListStudiesResponse( - studies=[ - study.Study(), - study.Study(), - ], + studies=[study.Study(), study.Study(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_studies(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_study(transport: str = 'grpc', request_type=vizier_service.DeleteStudyRequest): +def test_delete_study( + transport: str = "grpc", request_type=vizier_service.DeleteStudyRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1332,9 +1258,7 @@ def test_delete_study(transport: str = 'grpc', request_type=vizier_service.Delet request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1358,25 +1282,24 @@ def test_delete_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: client.delete_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.DeleteStudyRequest() + @pytest.mark.asyncio -async def test_delete_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.DeleteStudyRequest): +async def test_delete_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.DeleteStudyRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1384,9 +1307,7 @@ async def test_delete_study_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1408,19 +1329,15 @@ async def test_delete_study_async_from_dict(): def test_delete_study_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteStudyRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: call.return_value = None client.delete_study(request) @@ -1432,27 +1349,20 @@ def test_delete_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_study_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteStudyRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.delete_study(request) @@ -1464,99 +1374,79 @@ async def test_delete_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_study_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_study( - name='name_value', - ) + client.delete_study(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_study_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_study( - vizier_service.DeleteStudyRequest(), - name='name_value', + vizier_service.DeleteStudyRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_study_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_study), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_study( - name='name_value', - ) + response = await client.delete_study(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_study_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_study( - vizier_service.DeleteStudyRequest(), - name='name_value', + vizier_service.DeleteStudyRequest(), name="name_value", ) -def test_lookup_study(transport: str = 'grpc', request_type=vizier_service.LookupStudyRequest): +def test_lookup_study( + transport: str = "grpc", request_type=vizier_service.LookupStudyRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1564,19 +1454,13 @@ def test_lookup_study(transport: str = 'grpc', request_type=vizier_service.Looku request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Study( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=study.Study.State.ACTIVE, - - inactive_reason='inactive_reason_value', - + inactive_reason="inactive_reason_value", ) response = client.lookup_study(request) @@ -1591,13 +1475,13 @@ def test_lookup_study(transport: str = 'grpc', request_type=vizier_service.Looku assert isinstance(response, study.Study) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == 'inactive_reason_value' + assert response.inactive_reason == "inactive_reason_value" def test_lookup_study_from_dict(): @@ -1608,25 +1492,24 @@ def test_lookup_study_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: client.lookup_study() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.LookupStudyRequest() + @pytest.mark.asyncio -async def test_lookup_study_async(transport: str = 'grpc_asyncio', request_type=vizier_service.LookupStudyRequest): +async def test_lookup_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.LookupStudyRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1634,16 +1517,16 @@ async def test_lookup_study_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study( - name='name_value', - display_name='display_name_value', - state=study.Study.State.ACTIVE, - inactive_reason='inactive_reason_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Study( + name="name_value", + display_name="display_name_value", + state=study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + ) response = await client.lookup_study(request) @@ -1656,13 +1539,13 @@ async def test_lookup_study_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, study.Study) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == study.Study.State.ACTIVE - assert response.inactive_reason == 'inactive_reason_value' + assert response.inactive_reason == "inactive_reason_value" @pytest.mark.asyncio @@ -1671,19 +1554,15 @@ async def test_lookup_study_async_from_dict(): def test_lookup_study_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.LookupStudyRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: call.return_value = study.Study() client.lookup_study(request) @@ -1695,27 +1574,20 @@ def test_lookup_study_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_lookup_study_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.LookupStudyRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) await client.lookup_study(request) @@ -1727,99 +1599,79 @@ async def test_lookup_study_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_lookup_study_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Study() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.lookup_study( - parent='parent_value', - ) + client.lookup_study(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_lookup_study_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.lookup_study( - vizier_service.LookupStudyRequest(), - parent='parent_value', + vizier_service.LookupStudyRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_lookup_study_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.lookup_study), - '__call__') as call: + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Study() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.lookup_study( - parent='parent_value', - ) + response = await client.lookup_study(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_lookup_study_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.lookup_study( - vizier_service.LookupStudyRequest(), - parent='parent_value', + vizier_service.LookupStudyRequest(), parent="parent_value", ) -def test_suggest_trials(transport: str = 'grpc', request_type=vizier_service.SuggestTrialsRequest): +def test_suggest_trials( + transport: str = "grpc", request_type=vizier_service.SuggestTrialsRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1827,11 +1679,9 @@ def test_suggest_trials(transport: str = 'grpc', request_type=vizier_service.Sug request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.suggest_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.suggest_trials(request) @@ -1853,25 +1703,24 @@ def test_suggest_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.suggest_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: client.suggest_trials() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.SuggestTrialsRequest() + @pytest.mark.asyncio -async def test_suggest_trials_async(transport: str = 'grpc_asyncio', request_type=vizier_service.SuggestTrialsRequest): +async def test_suggest_trials_async( + transport: str = "grpc_asyncio", request_type=vizier_service.SuggestTrialsRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1879,12 +1728,10 @@ async def test_suggest_trials_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.suggest_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.suggest_trials(request) @@ -1905,20 +1752,16 @@ async def test_suggest_trials_async_from_dict(): def test_suggest_trials_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.SuggestTrialsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.suggest_trials), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.suggest_trials(request) @@ -1929,28 +1772,23 @@ def test_suggest_trials_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_suggest_trials_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.SuggestTrialsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.suggest_trials), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.suggest_trials(request) @@ -1961,16 +1799,14 @@ async def test_suggest_trials_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_create_trial(transport: str = 'grpc', request_type=vizier_service.CreateTrialRequest): +def test_create_trial( + transport: str = "grpc", request_type=vizier_service.CreateTrialRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1978,23 +1814,15 @@ def test_create_trial(transport: str = 'grpc', request_type=vizier_service.Creat request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name='name_value', - - id='id_value', - + name="name_value", + id="id_value", state=study.Trial.State.REQUESTED, - - client_id='client_id_value', - - infeasible_reason='infeasible_reason_value', - - custom_job='custom_job_value', - + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", ) response = client.create_trial(request) @@ -2009,17 +1837,17 @@ def test_create_trial(transport: str = 'grpc', request_type=vizier_service.Creat assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" def test_create_trial_from_dict(): @@ -2030,25 +1858,24 @@ def test_create_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: client.create_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CreateTrialRequest() + @pytest.mark.asyncio -async def test_create_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CreateTrialRequest): +async def test_create_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2056,18 +1883,18 @@ async def test_create_trial_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( - name='name_value', - id='id_value', - state=study.Trial.State.REQUESTED, - client_id='client_id_value', - infeasible_reason='infeasible_reason_value', - custom_job='custom_job_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", + ) + ) response = await client.create_trial(request) @@ -2080,17 +1907,17 @@ async def test_create_trial_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" @pytest.mark.asyncio @@ -2099,19 +1926,15 @@ async def test_create_trial_async_from_dict(): def test_create_trial_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateTrialRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: call.return_value = study.Trial() client.create_trial(request) @@ -2123,27 +1946,20 @@ def test_create_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_trial_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CreateTrialRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.create_trial(request) @@ -2155,29 +1971,21 @@ async def test_create_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_trial_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_trial( - parent='parent_value', - trial=study.Trial(name='name_value'), + parent="parent_value", trial=study.Trial(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2185,36 +1993,30 @@ def test_create_trial_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].trial == study.Trial(name='name_value') + assert args[0].trial == study.Trial(name="name_value") def test_create_trial_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_trial( vizier_service.CreateTrialRequest(), - parent='parent_value', - trial=study.Trial(name='name_value'), + parent="parent_value", + trial=study.Trial(name="name_value"), ) @pytest.mark.asyncio async def test_create_trial_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() @@ -2222,8 +2024,7 @@ async def test_create_trial_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_trial( - parent='parent_value', - trial=study.Trial(name='name_value'), + parent="parent_value", trial=study.Trial(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2231,31 +2032,30 @@ async def test_create_trial_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].trial == study.Trial(name='name_value') + assert args[0].trial == study.Trial(name="name_value") @pytest.mark.asyncio async def test_create_trial_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_trial( vizier_service.CreateTrialRequest(), - parent='parent_value', - trial=study.Trial(name='name_value'), + parent="parent_value", + trial=study.Trial(name="name_value"), ) -def test_get_trial(transport: str = 'grpc', request_type=vizier_service.GetTrialRequest): +def test_get_trial( + transport: str = "grpc", request_type=vizier_service.GetTrialRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2263,23 +2063,15 @@ def test_get_trial(transport: str = 'grpc', request_type=vizier_service.GetTrial request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name='name_value', - - id='id_value', - + name="name_value", + id="id_value", state=study.Trial.State.REQUESTED, - - client_id='client_id_value', - - infeasible_reason='infeasible_reason_value', - - custom_job='custom_job_value', - + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", ) response = client.get_trial(request) @@ -2294,17 +2086,17 @@ def test_get_trial(transport: str = 'grpc', request_type=vizier_service.GetTrial assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" def test_get_trial_from_dict(): @@ -2315,25 +2107,24 @@ def test_get_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: client.get_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.GetTrialRequest() + @pytest.mark.asyncio -async def test_get_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.GetTrialRequest): +async def test_get_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.GetTrialRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2341,18 +2132,18 @@ async def test_get_trial_async(transport: str = 'grpc_asyncio', request_type=viz request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( - name='name_value', - id='id_value', - state=study.Trial.State.REQUESTED, - client_id='client_id_value', - infeasible_reason='infeasible_reason_value', - custom_job='custom_job_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", + ) + ) response = await client.get_trial(request) @@ -2365,17 +2156,17 @@ async def test_get_trial_async(transport: str = 'grpc_asyncio', request_type=viz # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" @pytest.mark.asyncio @@ -2384,19 +2175,15 @@ async def test_get_trial_async_from_dict(): def test_get_trial_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: call.return_value = study.Trial() client.get_trial(request) @@ -2408,27 +2195,20 @@ def test_get_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_trial_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.GetTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.get_trial(request) @@ -2440,99 +2220,79 @@ async def test_get_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_trial_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_trial( - name='name_value', - ) + client.get_trial(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_trial_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_trial( - vizier_service.GetTrialRequest(), - name='name_value', + vizier_service.GetTrialRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_trial_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_trial( - name='name_value', - ) + response = await client.get_trial(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_trial_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_trial( - vizier_service.GetTrialRequest(), - name='name_value', + vizier_service.GetTrialRequest(), name="name_value", ) -def test_list_trials(transport: str = 'grpc', request_type=vizier_service.ListTrialsRequest): +def test_list_trials( + transport: str = "grpc", request_type=vizier_service.ListTrialsRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2540,13 +2300,10 @@ def test_list_trials(transport: str = 'grpc', request_type=vizier_service.ListTr request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListTrialsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_trials(request) @@ -2561,7 +2318,7 @@ def test_list_trials(transport: str = 'grpc', request_type=vizier_service.ListTr assert isinstance(response, pagers.ListTrialsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_trials_from_dict(): @@ -2572,25 +2329,24 @@ def test_list_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: client.list_trials() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.ListTrialsRequest() + @pytest.mark.asyncio -async def test_list_trials_async(transport: str = 'grpc_asyncio', request_type=vizier_service.ListTrialsRequest): +async def test_list_trials_async( + transport: str = "grpc_asyncio", request_type=vizier_service.ListTrialsRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2598,13 +2354,11 @@ async def test_list_trials_async(transport: str = 'grpc_asyncio', request_type=v request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListTrialsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListTrialsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_trials(request) @@ -2617,7 +2371,7 @@ async def test_list_trials_async(transport: str = 'grpc_asyncio', request_type=v # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrialsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2626,19 +2380,15 @@ async def test_list_trials_async_from_dict(): def test_list_trials_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListTrialsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: call.return_value = vizier_service.ListTrialsResponse() client.list_trials(request) @@ -2650,28 +2400,23 @@ def test_list_trials_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_trials_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListTrialsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListTrialsResponse()) + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListTrialsResponse() + ) await client.list_trials(request) @@ -2682,138 +2427,98 @@ async def test_list_trials_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_trials_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListTrialsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_trials( - parent='parent_value', - ) + client.list_trials(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_trials_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_trials( - vizier_service.ListTrialsRequest(), - parent='parent_value', + vizier_service.ListTrialsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_trials_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListTrialsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListTrialsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListTrialsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_trials( - parent='parent_value', - ) + response = await client.list_trials(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_trials_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_trials( - vizier_service.ListTrialsRequest(), - parent='parent_value', + vizier_service.ListTrialsRequest(), parent="parent_value", ) def test_list_trials_pager(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - study.Trial(), - ], - next_page_token='abc', - ), - vizier_service.ListTrialsResponse( - trials=[], - next_page_token='def', - ), - vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - ], - next_page_token='ghi', + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - ], + trials=[study.Trial(),], next_page_token="ghi", ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_trials(request={}) @@ -2821,147 +2526,96 @@ def test_list_trials_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, study.Trial) - for i in results) + assert all(isinstance(i, study.Trial) for i in results) + def test_list_trials_pages(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_trials), - '__call__') as call: + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - study.Trial(), - ], - next_page_token='abc', - ), - vizier_service.ListTrialsResponse( - trials=[], - next_page_token='def', - ), - vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - ], - next_page_token='ghi', + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - ], + trials=[study.Trial(),], next_page_token="ghi", ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) pages = list(client.list_trials(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_trials_async_pager(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_trials), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_trials), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - study.Trial(), - ], - next_page_token='abc', - ), - vizier_service.ListTrialsResponse( - trials=[], - next_page_token='def', - ), - vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - ], - next_page_token='ghi', + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - ], + trials=[study.Trial(),], next_page_token="ghi", ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) async_pager = await client.list_trials(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, study.Trial) - for i in responses) + assert all(isinstance(i, study.Trial) for i in responses) + @pytest.mark.asyncio async def test_list_trials_async_pages(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_trials), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_trials), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - study.Trial(), - ], - next_page_token='abc', - ), - vizier_service.ListTrialsResponse( - trials=[], - next_page_token='def', - ), - vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - ], - next_page_token='ghi', + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), vizier_service.ListTrialsResponse( - trials=[ - study.Trial(), - study.Trial(), - ], + trials=[study.Trial(),], next_page_token="ghi", ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_trials(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_add_trial_measurement(transport: str = 'grpc', request_type=vizier_service.AddTrialMeasurementRequest): +def test_add_trial_measurement( + transport: str = "grpc", request_type=vizier_service.AddTrialMeasurementRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2970,22 +2624,16 @@ def test_add_trial_measurement(transport: str = 'grpc', request_type=vizier_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), - '__call__') as call: + type(client.transport.add_trial_measurement), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name='name_value', - - id='id_value', - + name="name_value", + id="id_value", state=study.Trial.State.REQUESTED, - - client_id='client_id_value', - - infeasible_reason='infeasible_reason_value', - - custom_job='custom_job_value', - + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", ) response = client.add_trial_measurement(request) @@ -3000,17 +2648,17 @@ def test_add_trial_measurement(transport: str = 'grpc', request_type=vizier_serv assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" def test_add_trial_measurement_from_dict(): @@ -3021,25 +2669,27 @@ def test_add_trial_measurement_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), - '__call__') as call: + type(client.transport.add_trial_measurement), "__call__" + ) as call: client.add_trial_measurement() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.AddTrialMeasurementRequest() + @pytest.mark.asyncio -async def test_add_trial_measurement_async(transport: str = 'grpc_asyncio', request_type=vizier_service.AddTrialMeasurementRequest): +async def test_add_trial_measurement_async( + transport: str = "grpc_asyncio", + request_type=vizier_service.AddTrialMeasurementRequest, +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3048,17 +2698,19 @@ async def test_add_trial_measurement_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), - '__call__') as call: + type(client.transport.add_trial_measurement), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( - name='name_value', - id='id_value', - state=study.Trial.State.REQUESTED, - client_id='client_id_value', - infeasible_reason='infeasible_reason_value', - custom_job='custom_job_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", + ) + ) response = await client.add_trial_measurement(request) @@ -3071,17 +2723,17 @@ async def test_add_trial_measurement_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" @pytest.mark.asyncio @@ -3090,19 +2742,17 @@ async def test_add_trial_measurement_async_from_dict(): def test_add_trial_measurement_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.AddTrialMeasurementRequest() - request.trial_name = 'trial_name/value' + request.trial_name = "trial_name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), - '__call__') as call: + type(client.transport.add_trial_measurement), "__call__" + ) as call: call.return_value = study.Trial() client.add_trial_measurement(request) @@ -3114,27 +2764,22 @@ def test_add_trial_measurement_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'trial_name=trial_name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_trial_measurement_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.AddTrialMeasurementRequest() - request.trial_name = 'trial_name/value' + request.trial_name = "trial_name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_trial_measurement), - '__call__') as call: + type(client.transport.add_trial_measurement), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.add_trial_measurement(request) @@ -3146,16 +2791,14 @@ async def test_add_trial_measurement_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'trial_name=trial_name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] -def test_complete_trial(transport: str = 'grpc', request_type=vizier_service.CompleteTrialRequest): +def test_complete_trial( + transport: str = "grpc", request_type=vizier_service.CompleteTrialRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3163,23 +2806,15 @@ def test_complete_trial(transport: str = 'grpc', request_type=vizier_service.Com request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.complete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name='name_value', - - id='id_value', - + name="name_value", + id="id_value", state=study.Trial.State.REQUESTED, - - client_id='client_id_value', - - infeasible_reason='infeasible_reason_value', - - custom_job='custom_job_value', - + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", ) response = client.complete_trial(request) @@ -3194,17 +2829,17 @@ def test_complete_trial(transport: str = 'grpc', request_type=vizier_service.Com assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" def test_complete_trial_from_dict(): @@ -3215,25 +2850,24 @@ def test_complete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.complete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: client.complete_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CompleteTrialRequest() + @pytest.mark.asyncio -async def test_complete_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CompleteTrialRequest): +async def test_complete_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CompleteTrialRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3241,18 +2875,18 @@ async def test_complete_trial_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.complete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( - name='name_value', - id='id_value', - state=study.Trial.State.REQUESTED, - client_id='client_id_value', - infeasible_reason='infeasible_reason_value', - custom_job='custom_job_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", + ) + ) response = await client.complete_trial(request) @@ -3265,17 +2899,17 @@ async def test_complete_trial_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" @pytest.mark.asyncio @@ -3284,19 +2918,15 @@ async def test_complete_trial_async_from_dict(): def test_complete_trial_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CompleteTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.complete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: call.return_value = study.Trial() client.complete_trial(request) @@ -3308,27 +2938,20 @@ def test_complete_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_complete_trial_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CompleteTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.complete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.complete_trial(request) @@ -3340,16 +2963,14 @@ async def test_complete_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_delete_trial(transport: str = 'grpc', request_type=vizier_service.DeleteTrialRequest): +def test_delete_trial( + transport: str = "grpc", request_type=vizier_service.DeleteTrialRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3357,9 +2978,7 @@ def test_delete_trial(transport: str = 'grpc', request_type=vizier_service.Delet request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None @@ -3383,25 +3002,24 @@ def test_delete_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: client.delete_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.DeleteTrialRequest() + @pytest.mark.asyncio -async def test_delete_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.DeleteTrialRequest): +async def test_delete_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.DeleteTrialRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3409,9 +3027,7 @@ async def test_delete_trial_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -3433,19 +3049,15 @@ async def test_delete_trial_async_from_dict(): def test_delete_trial_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: call.return_value = None client.delete_trial(request) @@ -3457,27 +3069,20 @@ def test_delete_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_trial_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.DeleteTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.delete_trial(request) @@ -3489,99 +3094,80 @@ async def test_delete_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_trial_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_trial( - name='name_value', - ) + client.delete_trial(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_trial_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_trial( - vizier_service.DeleteTrialRequest(), - name='name_value', + vizier_service.DeleteTrialRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_trial_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_trial( - name='name_value', - ) + response = await client.delete_trial(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_trial_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_trial( - vizier_service.DeleteTrialRequest(), - name='name_value', + vizier_service.DeleteTrialRequest(), name="name_value", ) -def test_check_trial_early_stopping_state(transport: str = 'grpc', request_type=vizier_service.CheckTrialEarlyStoppingStateRequest): +def test_check_trial_early_stopping_state( + transport: str = "grpc", + request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3590,10 +3176,10 @@ def test_check_trial_early_stopping_state(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), - '__call__') as call: + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.check_trial_early_stopping_state(request) @@ -3615,25 +3201,27 @@ def test_check_trial_early_stopping_state_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), - '__call__') as call: + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: client.check_trial_early_stopping_state() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() + @pytest.mark.asyncio -async def test_check_trial_early_stopping_state_async(transport: str = 'grpc_asyncio', request_type=vizier_service.CheckTrialEarlyStoppingStateRequest): +async def test_check_trial_early_stopping_state_async( + transport: str = "grpc_asyncio", + request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3642,11 +3230,11 @@ async def test_check_trial_early_stopping_state_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), - '__call__') as call: + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.check_trial_early_stopping_state(request) @@ -3667,20 +3255,18 @@ async def test_check_trial_early_stopping_state_async_from_dict(): def test_check_trial_early_stopping_state_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CheckTrialEarlyStoppingStateRequest() - request.trial_name = 'trial_name/value' + request.trial_name = "trial_name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.check_trial_early_stopping_state(request) @@ -3691,28 +3277,25 @@ def test_check_trial_early_stopping_state_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'trial_name=trial_name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_check_trial_early_stopping_state_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.CheckTrialEarlyStoppingStateRequest() - request.trial_name = 'trial_name/value' + request.trial_name = "trial_name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.check_trial_early_stopping_state), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.check_trial_early_stopping_state(request) @@ -3723,16 +3306,14 @@ async def test_check_trial_early_stopping_state_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'trial_name=trial_name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] -def test_stop_trial(transport: str = 'grpc', request_type=vizier_service.StopTrialRequest): +def test_stop_trial( + transport: str = "grpc", request_type=vizier_service.StopTrialRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3740,23 +3321,15 @@ def test_stop_trial(transport: str = 'grpc', request_type=vizier_service.StopTri request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.stop_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = study.Trial( - name='name_value', - - id='id_value', - + name="name_value", + id="id_value", state=study.Trial.State.REQUESTED, - - client_id='client_id_value', - - infeasible_reason='infeasible_reason_value', - - custom_job='custom_job_value', - + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", ) response = client.stop_trial(request) @@ -3771,17 +3344,17 @@ def test_stop_trial(transport: str = 'grpc', request_type=vizier_service.StopTri assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" def test_stop_trial_from_dict(): @@ -3792,25 +3365,24 @@ def test_stop_trial_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.stop_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: client.stop_trial() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.StopTrialRequest() + @pytest.mark.asyncio -async def test_stop_trial_async(transport: str = 'grpc_asyncio', request_type=vizier_service.StopTrialRequest): +async def test_stop_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.StopTrialRequest +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3818,18 +3390,18 @@ async def test_stop_trial_async(transport: str = 'grpc_asyncio', request_type=vi request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.stop_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial( - name='name_value', - id='id_value', - state=study.Trial.State.REQUESTED, - client_id='client_id_value', - infeasible_reason='infeasible_reason_value', - custom_job='custom_job_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + client_id="client_id_value", + infeasible_reason="infeasible_reason_value", + custom_job="custom_job_value", + ) + ) response = await client.stop_trial(request) @@ -3842,17 +3414,17 @@ async def test_stop_trial_async(transport: str = 'grpc_asyncio', request_type=vi # Establish that the response is the type that we expect. assert isinstance(response, study.Trial) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.id == 'id_value' + assert response.id == "id_value" assert response.state == study.Trial.State.REQUESTED - assert response.client_id == 'client_id_value' + assert response.client_id == "client_id_value" - assert response.infeasible_reason == 'infeasible_reason_value' + assert response.infeasible_reason == "infeasible_reason_value" - assert response.custom_job == 'custom_job_value' + assert response.custom_job == "custom_job_value" @pytest.mark.asyncio @@ -3861,19 +3433,15 @@ async def test_stop_trial_async_from_dict(): def test_stop_trial_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.StopTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.stop_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: call.return_value = study.Trial() client.stop_trial(request) @@ -3885,27 +3453,20 @@ def test_stop_trial_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_stop_trial_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.StopTrialRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.stop_trial), - '__call__') as call: + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) await client.stop_trial(request) @@ -3917,16 +3478,14 @@ async def test_stop_trial_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_list_optimal_trials(transport: str = 'grpc', request_type=vizier_service.ListOptimalTrialsRequest): +def test_list_optimal_trials( + transport: str = "grpc", request_type=vizier_service.ListOptimalTrialsRequest +): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3935,11 +3494,10 @@ def test_list_optimal_trials(transport: str = 'grpc', request_type=vizier_servic # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: + type(client.transport.list_optimal_trials), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = vizier_service.ListOptimalTrialsResponse( - ) + call.return_value = vizier_service.ListOptimalTrialsResponse() response = client.list_optimal_trials(request) @@ -3962,25 +3520,27 @@ def test_list_optimal_trials_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: + type(client.transport.list_optimal_trials), "__call__" + ) as call: client.list_optimal_trials() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == vizier_service.ListOptimalTrialsRequest() + @pytest.mark.asyncio -async def test_list_optimal_trials_async(transport: str = 'grpc_asyncio', request_type=vizier_service.ListOptimalTrialsRequest): +async def test_list_optimal_trials_async( + transport: str = "grpc_asyncio", + request_type=vizier_service.ListOptimalTrialsRequest, +): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3989,11 +3549,12 @@ async def test_list_optimal_trials_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: + type(client.transport.list_optimal_trials), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListOptimalTrialsResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListOptimalTrialsResponse() + ) response = await client.list_optimal_trials(request) @@ -4013,19 +3574,17 @@ async def test_list_optimal_trials_async_from_dict(): def test_list_optimal_trials_field_headers(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListOptimalTrialsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: + type(client.transport.list_optimal_trials), "__call__" + ) as call: call.return_value = vizier_service.ListOptimalTrialsResponse() client.list_optimal_trials(request) @@ -4037,28 +3596,25 @@ def test_list_optimal_trials_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_optimal_trials_field_headers_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = vizier_service.ListOptimalTrialsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListOptimalTrialsResponse()) + type(client.transport.list_optimal_trials), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListOptimalTrialsResponse() + ) await client.list_optimal_trials(request) @@ -4069,92 +3625,77 @@ async def test_list_optimal_trials_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_optimal_trials_flattened(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: + type(client.transport.list_optimal_trials), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListOptimalTrialsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_optimal_trials( - parent='parent_value', - ) + client.list_optimal_trials(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_optimal_trials_flattened_error(): - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_optimal_trials( - vizier_service.ListOptimalTrialsRequest(), - parent='parent_value', + vizier_service.ListOptimalTrialsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_optimal_trials_flattened_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_optimal_trials), - '__call__') as call: + type(client.transport.list_optimal_trials), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = vizier_service.ListOptimalTrialsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(vizier_service.ListOptimalTrialsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListOptimalTrialsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_optimal_trials( - parent='parent_value', - ) + response = await client.list_optimal_trials(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_optimal_trials_flattened_error_async(): - client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_optimal_trials( - vizier_service.ListOptimalTrialsRequest(), - parent='parent_value', + vizier_service.ListOptimalTrialsRequest(), parent="parent_value", ) @@ -4165,8 +3706,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -4185,8 +3725,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = VizierServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -4214,13 +3753,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.VizierServiceGrpcTransport, - transports.VizierServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -4228,13 +3770,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.VizierServiceGrpcTransport, - ) + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.VizierServiceGrpcTransport,) def test_vizier_service_base_transport_error(): @@ -4242,13 +3779,15 @@ def test_vizier_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.VizierServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_vizier_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.VizierServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -4257,22 +3796,22 @@ def test_vizier_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_study', - 'get_study', - 'list_studies', - 'delete_study', - 'lookup_study', - 'suggest_trials', - 'create_trial', - 'get_trial', - 'list_trials', - 'add_trial_measurement', - 'complete_trial', - 'delete_trial', - 'check_trial_early_stopping_state', - 'stop_trial', - 'list_optimal_trials', - ) + "create_study", + "get_study", + "list_studies", + "delete_study", + "lookup_study", + "suggest_trials", + "create_trial", + "get_trial", + "list_trials", + "add_trial_measurement", + "complete_trial", + "delete_trial", + "check_trial_early_stopping_state", + "stop_trial", + "list_optimal_trials", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -4285,23 +3824,28 @@ def test_vizier_service_base_transport(): def test_vizier_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.VizierServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_vizier_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.VizierServiceTransport() @@ -4310,11 +3854,11 @@ def test_vizier_service_base_transport_with_adc(): def test_vizier_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) VizierServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -4322,19 +3866,25 @@ def test_vizier_service_auth_adc(): def test_vizier_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.VizierServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.VizierServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.VizierServiceGrpcTransport, transports.VizierServiceGrpcAsyncIOTransport]) -def test_vizier_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_vizier_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -4343,15 +3893,13 @@ def test_vizier_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -4366,38 +3914,40 @@ def test_vizier_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_vizier_service_host_no_port(): client = VizierServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_vizier_service_host_with_port(): client = VizierServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_vizier_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.VizierServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -4405,12 +3955,11 @@ def test_vizier_service_grpc_transport_channel(): def test_vizier_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.VizierServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -4419,12 +3968,20 @@ def test_vizier_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.VizierServiceGrpcTransport, transports.VizierServiceGrpcAsyncIOTransport]) -def test_vizier_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -4433,7 +3990,7 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -4449,9 +4006,7 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -4465,17 +4020,23 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.VizierServiceGrpcTransport, transports.VizierServiceGrpcAsyncIOTransport]) -def test_vizier_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_vizier_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -4492,9 +4053,7 @@ def test_vizier_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -4507,16 +4066,12 @@ def test_vizier_service_transport_channel_mtls_with_adc( def test_vizier_service_grpc_lro_client(): client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -4524,16 +4079,12 @@ def test_vizier_service_grpc_lro_client(): def test_vizier_service_grpc_lro_async_client(): client = VizierServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -4544,17 +4095,18 @@ def test_custom_job_path(): location = "clam" custom_job = "whelk" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) actual = VizierServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "custom_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "custom_job": "nudibranch", } path = VizierServiceClient.custom_job_path(**expected) @@ -4562,22 +4114,24 @@ def test_parse_custom_job_path(): actual = VizierServiceClient.parse_custom_job_path(path) assert expected == actual + def test_study_path(): project = "cuttlefish" location = "mussel" study = "winkle" - expected = "projects/{project}/locations/{location}/studies/{study}".format(project=project, location=location, study=study, ) + expected = "projects/{project}/locations/{location}/studies/{study}".format( + project=project, location=location, study=study, + ) actual = VizierServiceClient.study_path(project, location, study) assert expected == actual def test_parse_study_path(): expected = { - "project": "nautilus", - "location": "scallop", - "study": "abalone", - + "project": "nautilus", + "location": "scallop", + "study": "abalone", } path = VizierServiceClient.study_path(**expected) @@ -4585,24 +4139,26 @@ def test_parse_study_path(): actual = VizierServiceClient.parse_study_path(path) assert expected == actual + def test_trial_path(): project = "squid" location = "clam" study = "whelk" trial = "octopus" - expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) actual = VizierServiceClient.trial_path(project, location, study, trial) assert expected == actual def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", - + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", } path = VizierServiceClient.trial_path(**expected) @@ -4610,18 +4166,20 @@ def test_parse_trial_path(): actual = VizierServiceClient.parse_trial_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = VizierServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = VizierServiceClient.common_billing_account_path(**expected) @@ -4629,18 +4187,18 @@ def test_parse_common_billing_account_path(): actual = VizierServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = VizierServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = VizierServiceClient.common_folder_path(**expected) @@ -4648,18 +4206,18 @@ def test_parse_common_folder_path(): actual = VizierServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = VizierServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = VizierServiceClient.common_organization_path(**expected) @@ -4667,18 +4225,18 @@ def test_parse_common_organization_path(): actual = VizierServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = VizierServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = VizierServiceClient.common_project_path(**expected) @@ -4686,20 +4244,22 @@ def test_parse_common_project_path(): actual = VizierServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = VizierServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = VizierServiceClient.common_location_path(**expected) @@ -4711,17 +4271,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.VizierServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.VizierServiceTransport, "_prep_wrapped_messages" + ) as prep: client = VizierServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.VizierServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.VizierServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = VizierServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) From 74c7969c62996295c4a106f3fa1e845bb6bf361a Mon Sep 17 00:00:00 2001 From: "Ignacio (Nacho) Cano" Date: Thu, 8 Apr 2021 16:42:11 -0700 Subject: [PATCH 04/36] feat: adding MetadataStore class (#293) * Adding MetadataStore class with its create method plus UTs. Cosmetic fix to allow string ids while validating ids. * feat: adding MetadataStore class with its create method plus UTs. fix: Allow string ids while validating ids. * fix: Adding import although lint complaints. * fix: changes after code review * fix: lint * fix: lint --- google/cloud/aiplatform/compat/__init__.py | 2 + .../aiplatform/compat/services/__init__.py | 4 + .../cloud/aiplatform/compat/types/__init__.py | 2 + google/cloud/aiplatform/metadata/__init__.py | 16 ++ .../aiplatform/metadata/metadata_store.py | 138 +++++++++++ google/cloud/aiplatform/utils.py | 16 +- tests/unit/aiplatform/test_metadata.py | 230 ++++++++++++++++++ tests/unit/aiplatform/test_utils.py | 3 +- 8 files changed, 407 insertions(+), 4 deletions(-) create mode 100644 google/cloud/aiplatform/metadata/__init__.py create mode 100644 google/cloud/aiplatform/metadata/metadata_store.py create mode 100644 tests/unit/aiplatform/test_metadata.py diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index 36d805c6cb..16cc83a9cd 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -34,6 +34,7 @@ services.specialist_pool_service_client = ( services.specialist_pool_service_client_v1beta1 ) + services.metadata_service_client = services.metadata_service_client_v1beta1 types.accelerator_type = types.accelerator_type_v1beta1 types.annotation = types.annotation_v1beta1 @@ -69,6 +70,7 @@ types.specialist_pool = types.specialist_pool_v1beta1 types.specialist_pool_service = types.specialist_pool_service_v1beta1 types.training_pipeline = types.training_pipeline_v1beta1 + types.metadata_service = types.metadata_service_v1beta1 if DEFAULT_VERSION == V1: diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py index 0888c27fbb..8cbe922cbf 100644 --- a/google/cloud/aiplatform/compat/services/__init__.py +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -36,6 +36,9 @@ from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( client as specialist_pool_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.metadata_service import ( + client as metadata_service_client_v1beta1, +) from google.cloud.aiplatform_v1.services.dataset_service import ( client as dataset_service_client_v1, @@ -76,4 +79,5 @@ pipeline_service_client_v1beta1, prediction_service_client_v1beta1, specialist_pool_service_client_v1beta1, + metadata_service_client_v1beta1, ) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index d03e0d2f3a..047f1dee1d 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -50,6 +50,7 @@ specialist_pool as specialist_pool_v1beta1, specialist_pool_service as specialist_pool_service_v1beta1, training_pipeline as training_pipeline_v1beta1, + metadata_service as metadata_service_v1beta1, ) from google.cloud.aiplatform_v1.types import ( accelerator_type as accelerator_type_v1, @@ -155,4 +156,5 @@ specialist_pool_v1beta1, specialist_pool_service_v1beta1, training_pipeline_v1beta1, + metadata_service_v1beta1, ) diff --git a/google/cloud/aiplatform/metadata/__init__.py b/google/cloud/aiplatform/metadata/__init__.py new file mode 100644 index 0000000000..2144d2e268 --- /dev/null +++ b/google/cloud/aiplatform/metadata/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py new file mode 100644 index 0000000000..3c187813a1 --- /dev/null +++ b/google/cloud/aiplatform/metadata/metadata_store.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional + +import logging +from google.auth import credentials as auth_credentials +from google.api_core import exceptions + +from google.cloud.aiplatform import base, initializer +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store + + +class _MetadataStore(base.AiPlatformResourceNounWithFutureManager): + """Managed MetadataStore resource for AI Platform""" + + client_class = utils.MetadataClientWithOverride + _is_client_prediction_client = False + _resource_noun = "metadataStores" + _getter_method = "get_metadata_store" + _delete_method = "delete_metadata_store" + + def __init__( + self, + metadata_store_name: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing MetadataStore given a MetadataStore name or ID. + + Args: + metadata_store_name (str): + Optional. A fully-qualified MetadataStore resource name or metadataStore ID. + Example: "projects/123/locations/us-central1/metadataStores/my-store" or + "my-store" when project and location are initialized or passed. + If not set, metadata_store_name will be set to "default". + project (str): + Optional project to retrieve resource from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve resource from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + """ + + super().__init__( + project=project, location=location, credentials=credentials, + ) + self._gca_resource = self._get_gca_resource(resource_name=metadata_store_name) + + @classmethod + def create( + cls, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + ) -> "_MetadataStore": + """Creates a new MetadataStore if it does not exist. + + Args: + metadata_store_id (str): + The {metadatastore} portion of the resource name with + the format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the metadata store. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + + + Returns: + metadata_store (MetadataStore): + Instantiated representation of the managed metadata store resource. + + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + gapic_metadata_store = gca_metadata_store.MetadataStore( + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + select_version=compat.V1BETA1, + ) + ) + + try: + api_client.create_metadata_store( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + metadata_store=gapic_metadata_store, + metadata_store_id=metadata_store_id, + ).result() + except exceptions.AlreadyExists: + logging.info("MetadataStore %s already exists" % metadata_store_id) + + return cls( + metadata_store_name=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index ec39038942..9450d3f425 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -37,6 +37,7 @@ model_service_client_v1beta1, pipeline_service_client_v1beta1, prediction_service_client_v1beta1, + metadata_service_client_v1beta1, ) from google.cloud.aiplatform.compat.services import ( dataset_service_client_v1, @@ -60,6 +61,7 @@ prediction_service_client_v1beta1.PredictionServiceClient, pipeline_service_client_v1beta1.PipelineServiceClient, job_service_client_v1beta1.JobServiceClient, + metadata_service_client_v1beta1.MetadataServiceClient, # v1 dataset_service_client_v1.DatasetServiceClient, endpoint_service_client_v1.EndpointServiceClient, @@ -69,12 +71,11 @@ job_service_client_v1.JobServiceClient, ) -# TODO(b/170334193): Add support for resource names with non-integer IDs # TODO(b/170334098): Add support for resource names more than one level deep RESOURCE_NAME_PATTERN = re.compile( - r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P\w+)\/(?P\d+)$" + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P\w+)\/(?P[\w-]+)$" ) -RESOURCE_ID_PATTERN = re.compile(r"^\d+$") +RESOURCE_ID_PATTERN = re.compile(r"^[\w-]+$") Fields = namedtuple("Fields", ["project", "location", "resource", "id"],) @@ -454,6 +455,14 @@ class PredictionClientWithOverride(ClientWithOverride): ) +class MetadataClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.V1BETA1 + _version_map = ( + (compat.V1BETA1, metadata_service_client_v1beta1.MetadataServiceClient), + ) + + AiPlatformServiceClientWithOverride = TypeVar( "AiPlatformServiceClientWithOverride", DatasetClientWithOverride, @@ -462,6 +471,7 @@ class PredictionClientWithOverride(ClientWithOverride): ModelClientWithOverride, PipelineClientWithOverride, PredictionClientWithOverride, + MetadataClientWithOverride, ) diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py new file mode 100644 index 0000000000..8da3539380 --- /dev/null +++ b/tests/unit/aiplatform/test_metadata.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import pytest + +from unittest import mock +from importlib import reload +from unittest.mock import patch + +from google.api_core import operation +from google.auth.exceptions import GoogleAuthError +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform +from google.cloud.aiplatform.metadata import metadata_store +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform_v1beta1 import MetadataServiceClient +from google.cloud.aiplatform_v1beta1 import MetadataStore as GapicMetadataStore +from google.cloud.aiplatform_v1beta1.types import metadata_service +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ALT_LOCATION = "europe-west4" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +# metadata_store +_TEST_ID = "test-id" +_TEST_DEFAULT_ID = "default" + +_TEST_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/{_TEST_ID}" +) +_TEST_ALT_LOC_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_ALT_LOCATION}/metadataStores/{_TEST_ID}" +) +_TEST_DEFAULT_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/{_TEST_DEFAULT_ID}" + +_TEST_INVALID_NAME = f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/{_TEST_ID}" + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def get_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + name=_TEST_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def get_default_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + name=_TEST_DEFAULT_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def get_metadata_store_without_name_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def create_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "create_metadata_store" + ) as create_metadata_store_mock: + create_metadata_store_lro_mock = mock.Mock(operation.Operation) + create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( + name=_TEST_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_metadata_store_mock.return_value = create_metadata_store_lro_mock + yield create_metadata_store_mock + + +@pytest.fixture +def create_default_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "create_metadata_store" + ) as create_metadata_store_mock: + create_metadata_store_lro_mock = mock.Mock(operation.Operation) + create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( + name=_TEST_DEFAULT_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_metadata_store_mock.return_value = create_metadata_store_lro_mock + yield create_metadata_store_mock + + +@pytest.fixture +def delete_metadata_store_mock(): + with mock.patch.object( + MetadataServiceClient, "delete_metadata_store" + ) as delete_metadata_store_mock: + delete_metadata_store_lro_mock = mock.Mock(operation.Operation) + delete_metadata_store_lro_mock.result.return_value = ( + metadata_service.DeleteMetadataStoreRequest() + ) + delete_metadata_store_mock.return_value = delete_metadata_store_lro_mock + yield delete_metadata_store_mock + + +class TestMetadataStore: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_metadata_store(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT) + metadata_store._MetadataStore(metadata_store_name=_TEST_NAME) + get_metadata_store_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_metadata_store_with_id(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore(metadata_store_name=_TEST_ID) + get_metadata_store_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_metadata_store_with_default_id(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore() + get_metadata_store_mock.assert_called_once_with(name=_TEST_DEFAULT_NAME) + + @pytest.mark.usefixtures("get_metadata_store_without_name_mock") + @patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "", "GOOGLE_APPLICATION_CREDENTIALS": ""} + ) + def test_init_metadata_store_with_id_without_project_or_location(self): + with pytest.raises(GoogleAuthError): + metadata_store._MetadataStore( + metadata_store_name=_TEST_ID, + credentials=auth_credentials.AnonymousCredentials(), + ) + + def test_init_metadata_store_with_location_override(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore( + metadata_store_name=_TEST_ID, location=_TEST_ALT_LOCATION + ) + get_metadata_store_mock.assert_called_once_with(name=_TEST_ALT_LOC_NAME) + + @pytest.mark.usefixtures("get_metadata_store_mock") + def test_init_metadata_store_with_invalid_name(self): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore(metadata_store_name=_TEST_INVALID_NAME) + + @pytest.mark.usefixtures("get_default_metadata_store_mock") + def test_init_aiplatform_with_encryption_key_name_and_create_default_metadata_store( + self, create_default_metadata_store_mock + ): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_metadata_store = metadata_store._MetadataStore.create( + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + expected_metadata_store = GapicMetadataStore( + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_default_metadata_store_mock.assert_called_once_with( + parent=_TEST_PARENT, + metadata_store_id=_TEST_DEFAULT_ID, + metadata_store=expected_metadata_store, + ) + + expected_metadata_store.name = _TEST_DEFAULT_NAME + assert my_metadata_store._gca_resource == expected_metadata_store + + @pytest.mark.usefixtures("get_metadata_store_mock") + def test_create_non_default_metadata_store(self, create_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_metadata_store = metadata_store._MetadataStore.create( + metadata_store_id=_TEST_ID, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + expected_metadata_store = GapicMetadataStore( + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_metadata_store_mock.assert_called_once_with( + parent=_TEST_PARENT, + metadata_store_id=_TEST_ID, + metadata_store=expected_metadata_store, + ) + + expected_metadata_store.name = _TEST_NAME + assert my_metadata_store._gca_resource == expected_metadata_store diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 3032475069..03ca7cd6fe 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -44,7 +44,8 @@ ("projects/123456/locations/us-central1/datasets/987654", True), ("projects/857392/locations/us-central1/trainingPipelines/347292", True), ("projects/acme-co-proj-1/locations/us-central1/datasets/123456", True), - ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", False), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", True), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abc-def", True), ("project/123456/locations/us-central1/datasets/987654", False), ("project//locations//datasets/987654", False), ("locations/europe-west4/datasets/987654", False), From 0a9f964e1cf368726c886e187aeb78391dbfe9f5 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Mon, 12 Apr 2021 16:14:30 -0400 Subject: [PATCH 05/36] test --- .github/header-checker-lint.yml | 2 +- .github/sync-repo-settings.yaml | 1 - .pre-commit-config.yaml | 14 + .../featurestore_online_serving_service.rst | 6 + .../featurestore_service.rst | 11 + docs/aiplatform_v1beta1/services.rst | 2 + docs/conf.py | 13 + .../services/dataset_service/async_client.py | 2 +- .../services/dataset_service/client.py | 18 +- .../services/endpoint_service/async_client.py | 2 +- .../services/endpoint_service/client.py | 18 +- .../services/job_service/async_client.py | 2 +- .../services/job_service/client.py | 18 +- .../services/migration_service/client.py | 12 +- .../services/model_service/async_client.py | 2 +- .../services/model_service/client.py | 14 +- .../services/pipeline_service/async_client.py | 2 +- .../services/pipeline_service/client.py | 6 +- .../specialist_pool_service/async_client.py | 2 +- .../specialist_pool_service/client.py | 14 +- google/cloud/aiplatform_v1beta1/__init__.py | 114 + .../services/dataset_service/async_client.py | 2 +- .../services/dataset_service/client.py | 18 +- .../services/endpoint_service/async_client.py | 2 +- .../services/endpoint_service/client.py | 18 +- .../__init__.py | 24 + .../async_client.py | 339 + .../client.py | 513 ++ .../transports/__init__.py | 35 + .../transports/base.py | 141 + .../transports/grpc.py | 279 + .../transports/grpc_asyncio.py | 284 + .../services/featurestore_service/__init__.py | 24 + .../featurestore_service/async_client.py | 2019 ++++++ .../services/featurestore_service/client.py | 2227 ++++++ .../services/featurestore_service/pagers.py | 511 ++ .../transports/__init__.py | 35 + .../featurestore_service/transports/base.py | 391 + .../featurestore_service/transports/grpc.py | 772 ++ .../transports/grpc_asyncio.py | 777 ++ .../services/job_service/async_client.py | 2 +- .../services/job_service/client.py | 26 +- .../services/metadata_service/async_client.py | 2 +- .../services/metadata_service/client.py | 14 +- .../services/model_service/async_client.py | 2 +- .../services/model_service/client.py | 14 +- .../services/pipeline_service/async_client.py | 2 +- .../services/pipeline_service/client.py | 6 +- .../specialist_pool_service/async_client.py | 2 +- .../specialist_pool_service/client.py | 14 +- .../aiplatform_v1beta1/types/__init__.py | 126 + .../aiplatform_v1beta1/types/entity_type.py | 103 + .../cloud/aiplatform_v1beta1/types/feature.py | 134 + .../types/feature_selector.py | 59 + .../aiplatform_v1beta1/types/featurestore.py | 135 + .../types/featurestore_monitoring.py | 76 + .../types/featurestore_online_service.py | 343 + .../types/featurestore_service.py | 1202 ++++ google/cloud/aiplatform_v1beta1/types/io.py | 56 + .../cloud/aiplatform_v1beta1/types/model.py | 5 +- .../cloud/aiplatform_v1beta1/types/types.py | 76 + noxfile.py | 4 +- renovate.json | 5 +- .../aiplatform_v1/test_migration_service.py | 24 +- ...est_featurestore_online_serving_service.py | 1292 ++++ .../test_featurestore_service.py | 6377 +++++++++++++++++ 66 files changed, 18649 insertions(+), 138 deletions(-) create mode 100644 docs/aiplatform_v1beta1/featurestore_online_serving_service.rst create mode 100644 docs/aiplatform_v1beta1/featurestore_service.rst create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/types/entity_type.py create mode 100644 google/cloud/aiplatform_v1beta1/types/feature.py create mode 100644 google/cloud/aiplatform_v1beta1/types/feature_selector.py create mode 100644 google/cloud/aiplatform_v1beta1/types/featurestore.py create mode 100644 google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py create mode 100644 google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py create mode 100644 google/cloud/aiplatform_v1beta1/types/featurestore_service.py create mode 100644 google/cloud/aiplatform_v1beta1/types/types.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py diff --git a/.github/header-checker-lint.yml b/.github/header-checker-lint.yml index fc281c05bd..6fe78aa798 100644 --- a/.github/header-checker-lint.yml +++ b/.github/header-checker-lint.yml @@ -1,6 +1,6 @@ {"allowedCopyrightHolders": ["Google LLC"], "allowedLicenses": ["Apache-2.0", "MIT", "BSD-3"], - "ignoreFiles": ["**/requirements.txt", "**/requirements-test.txt"], + "ignoreFiles": ["**/requirements.txt", "**/requirements-test.txt", "**/__init__.py", "samples/**/constraints.txt", "samples/**/constraints-test.txt"], "sourceFileExtensions": [ "ts", "js", diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index b703be9596..1e00173609 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -9,4 +9,3 @@ branchProtectionRules: - 'Kokoro' - 'cla/google' - 'Samples - Lint' - - 'Samples - Python 3.7' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 32302e4883..8912e9b5d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,17 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: diff --git a/docs/aiplatform_v1beta1/featurestore_online_serving_service.rst b/docs/aiplatform_v1beta1/featurestore_online_serving_service.rst new file mode 100644 index 0000000000..21013eb751 --- /dev/null +++ b/docs/aiplatform_v1beta1/featurestore_online_serving_service.rst @@ -0,0 +1,6 @@ +FeaturestoreOnlineServingService +-------------------------------------------------- + +.. automodule:: google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service + :members: + :inherited-members: diff --git a/docs/aiplatform_v1beta1/featurestore_service.rst b/docs/aiplatform_v1beta1/featurestore_service.rst new file mode 100644 index 0000000000..d05deb4c2c --- /dev/null +++ b/docs/aiplatform_v1beta1/featurestore_service.rst @@ -0,0 +1,11 @@ +FeaturestoreService +------------------------------------- + +.. automodule:: google.cloud.aiplatform_v1beta1.services.featurestore_service + :members: + :inherited-members: + + +.. automodule:: google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers + :members: + :inherited-members: diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index 95202b1e99..7197956571 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -5,6 +5,8 @@ Services for Google Cloud Aiplatform v1beta1 API dataset_service endpoint_service + featurestore_online_serving_service + featurestore_service job_service metadata_service migration_service diff --git a/docs/conf.py b/docs/conf.py index c05116a68c..043d796523 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,17 @@ # -*- coding: utf-8 -*- +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # # google-cloud-aiplatform documentation build configuration file # diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index 950d920c5a..dfc64069bd 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.dataset_service import pagers from google.cloud.aiplatform_v1.types import annotation diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 52109ac90b..3e14ad0e50 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.dataset_service import pagers from google.cloud.aiplatform_v1.types import annotation @@ -378,7 +378,7 @@ def create_dataset(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Creates a Dataset. Args: @@ -458,7 +458,7 @@ def create_dataset(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_dataset.Dataset, @@ -737,7 +737,7 @@ def delete_dataset(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a Dataset. Args: @@ -820,7 +820,7 @@ def delete_dataset(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -838,7 +838,7 @@ def import_data(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Imports data into a Dataset. Args: @@ -921,7 +921,7 @@ def import_data(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, dataset_service.ImportDataResponse, @@ -939,7 +939,7 @@ def export_data(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Exports data from a Dataset. Args: @@ -1021,7 +1021,7 @@ def export_data(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, dataset_service.ExportDataResponse, diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index 244c35bcba..ab1a3d3daf 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.endpoint_service import pagers from google.cloud.aiplatform_v1.types import encryption_spec diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 3b78f5902e..9be4771620 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.endpoint_service import pagers from google.cloud.aiplatform_v1.types import encryption_spec @@ -352,7 +352,7 @@ def create_endpoint(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Creates an Endpoint. Args: @@ -431,7 +431,7 @@ def create_endpoint(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_endpoint.Endpoint, @@ -707,7 +707,7 @@ def delete_endpoint(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes an Endpoint. Args: @@ -790,7 +790,7 @@ def delete_endpoint(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -809,7 +809,7 @@ def deploy_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -919,7 +919,7 @@ def deploy_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, endpoint_service.DeployModelResponse, @@ -938,7 +938,7 @@ def undeploy_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -1039,7 +1039,7 @@ def undeploy_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, endpoint_service.UndeployModelResponse, diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index e76498a85d..5d9a5d68b5 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index 1a304de108..a3cc318097 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job @@ -688,7 +688,7 @@ def delete_custom_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a CustomJob. Args: @@ -771,7 +771,7 @@ def delete_custom_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -1128,7 +1128,7 @@ def delete_data_labeling_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1212,7 +1212,7 @@ def delete_data_labeling_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -1564,7 +1564,7 @@ def delete_hyperparameter_tuning_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1648,7 +1648,7 @@ def delete_hyperparameter_tuning_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -2018,7 +2018,7 @@ def delete_batch_prediction_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2103,7 +2103,7 @@ def delete_batch_prediction_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 94758701d8..9e6cf8c669 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -184,25 +184,25 @@ def parse_dataset_path(path: str) -> Dict[str,str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index a65c5df60f..f549b5e68d 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.model_service import pagers from google.cloud.aiplatform_v1.types import deployed_model_ref diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index 9d5ebc8008..f0237a4359 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.model_service import pagers from google.cloud.aiplatform_v1.types import deployed_model_ref @@ -389,7 +389,7 @@ def upload_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -471,7 +471,7 @@ def upload_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, model_service.UploadModelResponse, @@ -742,7 +742,7 @@ def delete_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -827,7 +827,7 @@ def delete_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -845,7 +845,7 @@ def export_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -932,7 +932,7 @@ def export_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, model_service.ExportModelResponse, diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index 276c0980f5..3b43bc080c 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.pipeline_service import pagers from google.cloud.aiplatform_v1.types import encryption_spec diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index fe36174dda..39d6f60f89 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.pipeline_service import pagers from google.cloud.aiplatform_v1.types import encryption_spec @@ -634,7 +634,7 @@ def delete_training_pipeline(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -718,7 +718,7 @@ def delete_training_pipeline(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index be193ead83..c05ca17005 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1.types import operation as gca_operation diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index efb32eaa6e..968bf5dbd4 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1.types import operation as gca_operation @@ -345,7 +345,7 @@ def create_specialist_pool(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -432,7 +432,7 @@ def create_specialist_pool(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_specialist_pool.SpecialistPool, @@ -629,7 +629,7 @@ def delete_specialist_pool(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -713,7 +713,7 @@ def delete_specialist_pool(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -731,7 +731,7 @@ def update_specialist_pool(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -817,7 +817,7 @@ def update_specialist_pool(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_specialist_pool.SpecialistPool, diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 0dbcbec2d6..2936282360 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -17,6 +17,8 @@ from .services.dataset_service import DatasetServiceClient from .services.endpoint_service import EndpointServiceClient +from .services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceClient +from .services.featurestore_service import FeaturestoreServiceClient from .services.job_service import JobServiceClient from .services.metadata_service import MetadataServiceClient from .services.migration_service import MigrationServiceClient @@ -81,6 +83,7 @@ from .types.endpoint_service import UndeployModelRequest from .types.endpoint_service import UndeployModelResponse from .types.endpoint_service import UpdateEndpointRequest +from .types.entity_type import EntityType from .types.env_var import EnvVar from .types.event import Event from .types.execution import Execution @@ -97,13 +100,63 @@ from .types.explanation import SmoothGradConfig from .types.explanation import XraiAttribution from .types.explanation_metadata import ExplanationMetadata +from .types.feature import Feature from .types.feature_monitoring_stats import FeatureStatsAnomaly +from .types.feature_selector import FeatureSelector +from .types.feature_selector import IdMatcher +from .types.featurestore import Featurestore +from .types.featurestore_monitoring import FeaturestoreMonitoringConfig +from .types.featurestore_online_service import FeatureValue +from .types.featurestore_online_service import FeatureValueList +from .types.featurestore_online_service import ReadFeatureValuesRequest +from .types.featurestore_online_service import ReadFeatureValuesResponse +from .types.featurestore_online_service import ReadSetting +from .types.featurestore_online_service import StreamingReadFeatureValuesRequest +from .types.featurestore_service import BatchCreateFeaturesOperationMetadata +from .types.featurestore_service import BatchCreateFeaturesRequest +from .types.featurestore_service import BatchCreateFeaturesResponse +from .types.featurestore_service import BatchReadFeatureValuesOperationMetadata +from .types.featurestore_service import BatchReadFeatureValuesRequest +from .types.featurestore_service import BatchReadFeatureValuesResponse +from .types.featurestore_service import CreateEntityTypeOperationMetadata +from .types.featurestore_service import CreateEntityTypeRequest +from .types.featurestore_service import CreateFeatureOperationMetadata +from .types.featurestore_service import CreateFeatureRequest +from .types.featurestore_service import CreateFeaturestoreOperationMetadata +from .types.featurestore_service import CreateFeaturestoreRequest +from .types.featurestore_service import DeleteEntityTypeRequest +from .types.featurestore_service import DeleteFeatureRequest +from .types.featurestore_service import DeleteFeaturestoreRequest +from .types.featurestore_service import DestinationFeatureSetting +from .types.featurestore_service import FeatureValueDestination +from .types.featurestore_service import GetEntityTypeRequest +from .types.featurestore_service import GetFeatureRequest +from .types.featurestore_service import GetFeaturestoreRequest +from .types.featurestore_service import ImportFeatureValuesOperationMetadata +from .types.featurestore_service import ImportFeatureValuesRequest +from .types.featurestore_service import ImportFeatureValuesResponse +from .types.featurestore_service import ListEntityTypesRequest +from .types.featurestore_service import ListEntityTypesResponse +from .types.featurestore_service import ListFeaturesRequest +from .types.featurestore_service import ListFeaturesResponse +from .types.featurestore_service import ListFeaturestoresRequest +from .types.featurestore_service import ListFeaturestoresResponse +from .types.featurestore_service import SearchFeaturesRequest +from .types.featurestore_service import SearchFeaturesResponse +from .types.featurestore_service import UpdateEntityTypeRequest +from .types.featurestore_service import UpdateFeatureRequest +from .types.featurestore_service import UpdateFeaturestoreOperationMetadata +from .types.featurestore_service import UpdateFeaturestoreRequest from .types.hyperparameter_tuning_job import HyperparameterTuningJob +from .types.io import AvroSource from .types.io import BigQueryDestination from .types.io import BigQuerySource from .types.io import ContainerRegistryDestination +from .types.io import CsvDestination +from .types.io import CsvSource from .types.io import GcsDestination from .types.io import GcsSource +from .types.io import TFRecordDestination from .types.job_service import CancelBatchPredictionJobRequest from .types.job_service import CancelCustomJobRequest from .types.job_service import CancelDataLabelingJobRequest @@ -259,6 +312,10 @@ from .types.training_pipeline import PredefinedSplit from .types.training_pipeline import TimestampSplit from .types.training_pipeline import TrainingPipeline +from .types.types import BoolArray +from .types.types import DoubleArray +from .types.types import Int64Array +from .types.types import StringArray from .types.user_action_reference import UserActionReference from .types.vizier_service import AddTrialMeasurementRequest from .types.vizier_service import CheckTrialEarlyStoppingStateMetatdata @@ -300,13 +357,21 @@ 'Attribution', 'AutomaticResources', 'AutoscalingMetricSpec', + 'AvroSource', + 'BatchCreateFeaturesOperationMetadata', + 'BatchCreateFeaturesRequest', + 'BatchCreateFeaturesResponse', 'BatchDedicatedResources', 'BatchMigrateResourcesOperationMetadata', 'BatchMigrateResourcesRequest', 'BatchMigrateResourcesResponse', 'BatchPredictionJob', + 'BatchReadFeatureValuesOperationMetadata', + 'BatchReadFeatureValuesRequest', + 'BatchReadFeatureValuesResponse', 'BigQueryDestination', 'BigQuerySource', + 'BoolArray', 'CancelBatchPredictionJobRequest', 'CancelCustomJobRequest', 'CancelDataLabelingJobRequest', @@ -329,7 +394,13 @@ 'CreateDatasetRequest', 'CreateEndpointOperationMetadata', 'CreateEndpointRequest', + 'CreateEntityTypeOperationMetadata', + 'CreateEntityTypeRequest', 'CreateExecutionRequest', + 'CreateFeatureOperationMetadata', + 'CreateFeatureRequest', + 'CreateFeaturestoreOperationMetadata', + 'CreateFeaturestoreRequest', 'CreateHyperparameterTuningJobRequest', 'CreateMetadataSchemaRequest', 'CreateMetadataStoreOperationMetadata', @@ -340,6 +411,8 @@ 'CreateStudyRequest', 'CreateTrainingPipelineRequest', 'CreateTrialRequest', + 'CsvDestination', + 'CsvSource', 'CustomJob', 'CustomJobSpec', 'DataItem', @@ -353,6 +426,9 @@ 'DeleteDataLabelingJobRequest', 'DeleteDatasetRequest', 'DeleteEndpointRequest', + 'DeleteEntityTypeRequest', + 'DeleteFeatureRequest', + 'DeleteFeaturestoreRequest', 'DeleteHyperparameterTuningJobRequest', 'DeleteMetadataStoreOperationMetadata', 'DeleteMetadataStoreRequest', @@ -368,10 +444,13 @@ 'DeployModelResponse', 'DeployedModel', 'DeployedModelRef', + 'DestinationFeatureSetting', 'DiskSpec', + 'DoubleArray', 'EncryptionSpec', 'Endpoint', 'EndpointServiceClient', + 'EntityType', 'EnvVar', 'Event', 'Execution', @@ -390,8 +469,17 @@ 'ExportModelOperationMetadata', 'ExportModelRequest', 'ExportModelResponse', + 'Feature', 'FeatureNoiseSigma', + 'FeatureSelector', 'FeatureStatsAnomaly', + 'FeatureValue', + 'FeatureValueDestination', + 'FeatureValueList', + 'Featurestore', + 'FeaturestoreMonitoringConfig', + 'FeaturestoreOnlineServingServiceClient', + 'FeaturestoreServiceClient', 'FilterSplit', 'FractionSplit', 'GcsDestination', @@ -405,7 +493,10 @@ 'GetDataLabelingJobRequest', 'GetDatasetRequest', 'GetEndpointRequest', + 'GetEntityTypeRequest', 'GetExecutionRequest', + 'GetFeatureRequest', + 'GetFeaturestoreRequest', 'GetHyperparameterTuningJobRequest', 'GetMetadataSchemaRequest', 'GetMetadataStoreRequest', @@ -418,11 +509,16 @@ 'GetTrainingPipelineRequest', 'GetTrialRequest', 'HyperparameterTuningJob', + 'IdMatcher', 'ImportDataConfig', 'ImportDataOperationMetadata', 'ImportDataRequest', 'ImportDataResponse', + 'ImportFeatureValuesOperationMetadata', + 'ImportFeatureValuesRequest', + 'ImportFeatureValuesResponse', 'InputDataConfig', + 'Int64Array', 'IntegratedGradientsAttribution', 'JobServiceClient', 'JobState', @@ -445,8 +541,14 @@ 'ListDatasetsResponse', 'ListEndpointsRequest', 'ListEndpointsResponse', + 'ListEntityTypesRequest', + 'ListEntityTypesResponse', 'ListExecutionsRequest', 'ListExecutionsResponse', + 'ListFeaturesRequest', + 'ListFeaturesResponse', + 'ListFeaturestoresRequest', + 'ListFeaturestoresResponse', 'ListHyperparameterTuningJobsRequest', 'ListHyperparameterTuningJobsResponse', 'ListMetadataSchemasRequest', @@ -507,12 +609,17 @@ 'PythonPackageSpec', 'QueryContextLineageSubgraphRequest', 'QueryExecutionInputsAndOutputsRequest', + 'ReadFeatureValuesRequest', + 'ReadFeatureValuesResponse', + 'ReadSetting', 'ResourcesConsumed', 'ResumeModelDeploymentMonitoringJobRequest', 'SampleConfig', 'SampledShapleyAttribution', 'SamplingStrategy', 'Scheduling', + 'SearchFeaturesRequest', + 'SearchFeaturesResponse', 'SearchMigratableResourcesRequest', 'SearchMigratableResourcesResponse', 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', @@ -521,11 +628,14 @@ 'SpecialistPool', 'SpecialistPoolServiceClient', 'StopTrialRequest', + 'StreamingReadFeatureValuesRequest', + 'StringArray', 'Study', 'StudySpec', 'SuggestTrialsMetadata', 'SuggestTrialsRequest', 'SuggestTrialsResponse', + 'TFRecordDestination', 'ThresholdConfig', 'TimestampSplit', 'TrainingConfig', @@ -538,7 +648,11 @@ 'UpdateContextRequest', 'UpdateDatasetRequest', 'UpdateEndpointRequest', + 'UpdateEntityTypeRequest', 'UpdateExecutionRequest', + 'UpdateFeatureRequest', + 'UpdateFeaturestoreOperationMetadata', + 'UpdateFeaturestoreRequest', 'UpdateModelDeploymentMonitoringJobOperationMetadata', 'UpdateModelDeploymentMonitoringJobRequest', 'UpdateModelRequest', diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 2eb9ce6f7a..b5d6c0c7ee 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.types import annotation diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 9d139e6b64..0dfe93b1eb 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.types import annotation @@ -378,7 +378,7 @@ def create_dataset(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Creates a Dataset. Args: @@ -458,7 +458,7 @@ def create_dataset(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_dataset.Dataset, @@ -737,7 +737,7 @@ def delete_dataset(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a Dataset. Args: @@ -820,7 +820,7 @@ def delete_dataset(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -838,7 +838,7 @@ def import_data(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Imports data into a Dataset. Args: @@ -921,7 +921,7 @@ def import_data(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, dataset_service.ImportDataResponse, @@ -939,7 +939,7 @@ def export_data(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Exports data from a Dataset. Args: @@ -1021,7 +1021,7 @@ def export_data(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, dataset_service.ExportDataResponse, diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index daadc92c9e..0f84bafb27 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.types import encryption_spec diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 78822a9489..21e209da37 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.types import encryption_spec @@ -352,7 +352,7 @@ def create_endpoint(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Creates an Endpoint. Args: @@ -431,7 +431,7 @@ def create_endpoint(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_endpoint.Endpoint, @@ -707,7 +707,7 @@ def delete_endpoint(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes an Endpoint. Args: @@ -790,7 +790,7 @@ def delete_endpoint(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -809,7 +809,7 @@ def deploy_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -919,7 +919,7 @@ def deploy_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, endpoint_service.DeployModelResponse, @@ -938,7 +938,7 @@ def undeploy_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -1039,7 +1039,7 @@ def undeploy_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, endpoint_service.UndeployModelResponse, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py new file mode 100644 index 0000000000..d5da9ac80e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import FeaturestoreOnlineServingServiceClient +from .async_client import FeaturestoreOnlineServingServiceAsyncClient + +__all__ = ( + 'FeaturestoreOnlineServingServiceClient', + 'FeaturestoreOnlineServingServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py new file mode 100644 index 0000000000..adb54190b0 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py @@ -0,0 +1,339 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1beta1.types import featurestore_online_service + +from .transports.base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import FeaturestoreOnlineServingServiceGrpcAsyncIOTransport +from .client import FeaturestoreOnlineServingServiceClient + + +class FeaturestoreOnlineServingServiceAsyncClient: + """A service for serving online feature values.""" + + _client: FeaturestoreOnlineServingServiceClient + + DEFAULT_ENDPOINT = FeaturestoreOnlineServingServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = FeaturestoreOnlineServingServiceClient.DEFAULT_MTLS_ENDPOINT + + entity_type_path = staticmethod(FeaturestoreOnlineServingServiceClient.entity_type_path) + parse_entity_type_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_entity_type_path) + + common_billing_account_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_project_path) + parse_common_project_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_project_path) + + common_location_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_location_path) + parse_common_location_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreOnlineServingServiceAsyncClient: The constructed client. + """ + return FeaturestoreOnlineServingServiceClient.from_service_account_info.__func__(FeaturestoreOnlineServingServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreOnlineServingServiceAsyncClient: The constructed client. + """ + return FeaturestoreOnlineServingServiceClient.from_service_account_file.__func__(FeaturestoreOnlineServingServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> FeaturestoreOnlineServingServiceTransport: + """Return the transport used by the client instance. + + Returns: + FeaturestoreOnlineServingServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(FeaturestoreOnlineServingServiceClient).get_transport_class, type(FeaturestoreOnlineServingServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, FeaturestoreOnlineServingServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the featurestore online serving service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.FeaturestoreOnlineServingServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = FeaturestoreOnlineServingServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def read_feature_values(self, + request: featurestore_online_service.ReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore_online_service.ReadFeatureValuesResponse: + r"""Reads Feature values of a specific entity of an + EntityType. For reading feature values of multiple + entities of an EntityType, please use + StreamingReadFeatureValues. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesRequest`): + The request object. Request message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + entity_type (:class:`str`): + Required. The resource name of the EntityType for the + entity being read. Value format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + For example, for a machine learning model predicting + user clicks on a website, an EntityType ID could be + "user". + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse: + Response message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_online_service.ReadFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.read_feature_values, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type', request.entity_type), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def streaming_read_feature_values(self, + request: featurestore_online_service.StreamingReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[AsyncIterable[featurestore_online_service.ReadFeatureValuesResponse]]: + r"""Reads Feature values for multiple entities. Depending + on their size, data for different entities may be broken + up across multiple responses. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.StreamingReadFeatureValuesRequest`): + The request object. Request message for + [FeaturestoreOnlineServingService.StreamingFeatureValuesRead][]. + entity_type (:class:`str`): + Required. The resource name of the entities' type. Value + format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + For example, for a machine learning model predicting + user clicks on a website, an EntityType ID could be + "user". + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + AsyncIterable[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse]: + Response message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_online_service.StreamingReadFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.streaming_read_feature_values, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type', request.entity_type), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'FeaturestoreOnlineServingServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py new file mode 100644 index 0000000000..7a1b71a568 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py @@ -0,0 +1,513 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Iterable, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1beta1.types import featurestore_online_service + +from .transports.base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import FeaturestoreOnlineServingServiceGrpcTransport +from .transports.grpc_asyncio import FeaturestoreOnlineServingServiceGrpcAsyncIOTransport + + +class FeaturestoreOnlineServingServiceClientMeta(type): + """Metaclass for the FeaturestoreOnlineServingService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]] + _transport_registry['grpc'] = FeaturestoreOnlineServingServiceGrpcTransport + _transport_registry['grpc_asyncio'] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport + + def get_transport_class(cls, + label: str = None, + ) -> Type[FeaturestoreOnlineServingServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class FeaturestoreOnlineServingServiceClient(metaclass=FeaturestoreOnlineServingServiceClientMeta): + """A service for serving online feature values.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreOnlineServingServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreOnlineServingServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> FeaturestoreOnlineServingServiceTransport: + """Return the transport used by the client instance. + + Returns: + FeaturestoreOnlineServingServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def entity_type_path(project: str,location: str,featurestore: str,entity_type: str,) -> str: + """Return a fully-qualified entity_type string.""" + return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) + + @staticmethod + def parse_entity_type_path(path: str) -> Dict[str,str]: + """Parse a entity_type path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, FeaturestoreOnlineServingServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the featurestore online serving service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, FeaturestoreOnlineServingServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, FeaturestoreOnlineServingServiceTransport): + # transport is a FeaturestoreOnlineServingServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def read_feature_values(self, + request: featurestore_online_service.ReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore_online_service.ReadFeatureValuesResponse: + r"""Reads Feature values of a specific entity of an + EntityType. For reading feature values of multiple + entities of an EntityType, please use + StreamingReadFeatureValues. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesRequest): + The request object. Request message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + entity_type (str): + Required. The resource name of the EntityType for the + entity being read. Value format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + For example, for a machine learning model predicting + user clicks on a website, an EntityType ID could be + "user". + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse: + Response message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_online_service.ReadFeatureValuesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_online_service.ReadFeatureValuesRequest): + request = featurestore_online_service.ReadFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.read_feature_values] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type', request.entity_type), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def streaming_read_feature_values(self, + request: featurestore_online_service.StreamingReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[featurestore_online_service.ReadFeatureValuesResponse]: + r"""Reads Feature values for multiple entities. Depending + on their size, data for different entities may be broken + up across multiple responses. + + Args: + request (google.cloud.aiplatform_v1beta1.types.StreamingReadFeatureValuesRequest): + The request object. Request message for + [FeaturestoreOnlineServingService.StreamingFeatureValuesRead][]. + entity_type (str): + Required. The resource name of the entities' type. Value + format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + For example, for a machine learning model predicting + user clicks on a website, an EntityType ID could be + "user". + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + Iterable[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse]: + Response message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_online_service.StreamingReadFeatureValuesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_online_service.StreamingReadFeatureValuesRequest): + request = featurestore_online_service.StreamingReadFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.streaming_read_feature_values] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type', request.entity_type), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'FeaturestoreOnlineServingServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py new file mode 100644 index 0000000000..e3326680c7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import FeaturestoreOnlineServingServiceTransport +from .grpc import FeaturestoreOnlineServingServiceGrpcTransport +from .grpc_asyncio import FeaturestoreOnlineServingServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]] +_transport_registry['grpc'] = FeaturestoreOnlineServingServiceGrpcTransport +_transport_registry['grpc_asyncio'] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport + +__all__ = ( + 'FeaturestoreOnlineServingServiceTransport', + 'FeaturestoreOnlineServingServiceGrpcTransport', + 'FeaturestoreOnlineServingServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py new file mode 100644 index 0000000000..8db9596f98 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import featurestore_online_service + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class FeaturestoreOnlineServingServiceTransport(abc.ABC): + """Abstract transport class for FeaturestoreOnlineServingService.""" + + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) + + def __init__( + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.read_feature_values: gapic_v1.method.wrap_method( + self.read_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.streaming_read_feature_values: gapic_v1.method.wrap_method( + self.streaming_read_feature_values, + default_timeout=None, + client_info=client_info, + ), + + } + + @property + def read_feature_values(self) -> typing.Callable[ + [featurestore_online_service.ReadFeatureValuesRequest], + typing.Union[ + featurestore_online_service.ReadFeatureValuesResponse, + typing.Awaitable[featurestore_online_service.ReadFeatureValuesResponse] + ]]: + raise NotImplementedError() + + @property + def streaming_read_feature_values(self) -> typing.Callable[ + [featurestore_online_service.StreamingReadFeatureValuesRequest], + typing.Union[ + featurestore_online_service.ReadFeatureValuesResponse, + typing.Awaitable[featurestore_online_service.ReadFeatureValuesResponse] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'FeaturestoreOnlineServingServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py new file mode 100644 index 0000000000..6ba3a31748 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import featurestore_online_service + +from .base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO + + +class FeaturestoreOnlineServingServiceGrpcTransport(FeaturestoreOnlineServingServiceTransport): + """gRPC backend transport for FeaturestoreOnlineServingService. + + A service for serving online feature values. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def read_feature_values(self) -> Callable[ + [featurestore_online_service.ReadFeatureValuesRequest], + featurestore_online_service.ReadFeatureValuesResponse]: + r"""Return a callable for the read feature values method over gRPC. + + Reads Feature values of a specific entity of an + EntityType. For reading feature values of multiple + entities of an EntityType, please use + StreamingReadFeatureValues. + + Returns: + Callable[[~.ReadFeatureValuesRequest], + ~.ReadFeatureValuesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'read_feature_values' not in self._stubs: + self._stubs['read_feature_values'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/ReadFeatureValues', + request_serializer=featurestore_online_service.ReadFeatureValuesRequest.serialize, + response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, + ) + return self._stubs['read_feature_values'] + + @property + def streaming_read_feature_values(self) -> Callable[ + [featurestore_online_service.StreamingReadFeatureValuesRequest], + featurestore_online_service.ReadFeatureValuesResponse]: + r"""Return a callable for the streaming read feature values method over gRPC. + + Reads Feature values for multiple entities. Depending + on their size, data for different entities may be broken + up across multiple responses. + + Returns: + Callable[[~.StreamingReadFeatureValuesRequest], + ~.ReadFeatureValuesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'streaming_read_feature_values' not in self._stubs: + self._stubs['streaming_read_feature_values'] = self.grpc_channel.unary_stream( + '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/StreamingReadFeatureValues', + request_serializer=featurestore_online_service.StreamingReadFeatureValuesRequest.serialize, + response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, + ) + return self._stubs['streaming_read_feature_values'] + + +__all__ = ( + 'FeaturestoreOnlineServingServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..bd03ab6626 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import featurestore_online_service + +from .base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import FeaturestoreOnlineServingServiceGrpcTransport + + +class FeaturestoreOnlineServingServiceGrpcAsyncIOTransport(FeaturestoreOnlineServingServiceTransport): + """gRPC AsyncIO backend transport for FeaturestoreOnlineServingService. + + A service for serving online feature values. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def read_feature_values(self) -> Callable[ + [featurestore_online_service.ReadFeatureValuesRequest], + Awaitable[featurestore_online_service.ReadFeatureValuesResponse]]: + r"""Return a callable for the read feature values method over gRPC. + + Reads Feature values of a specific entity of an + EntityType. For reading feature values of multiple + entities of an EntityType, please use + StreamingReadFeatureValues. + + Returns: + Callable[[~.ReadFeatureValuesRequest], + Awaitable[~.ReadFeatureValuesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'read_feature_values' not in self._stubs: + self._stubs['read_feature_values'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/ReadFeatureValues', + request_serializer=featurestore_online_service.ReadFeatureValuesRequest.serialize, + response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, + ) + return self._stubs['read_feature_values'] + + @property + def streaming_read_feature_values(self) -> Callable[ + [featurestore_online_service.StreamingReadFeatureValuesRequest], + Awaitable[featurestore_online_service.ReadFeatureValuesResponse]]: + r"""Return a callable for the streaming read feature values method over gRPC. + + Reads Feature values for multiple entities. Depending + on their size, data for different entities may be broken + up across multiple responses. + + Returns: + Callable[[~.StreamingReadFeatureValuesRequest], + Awaitable[~.ReadFeatureValuesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'streaming_read_feature_values' not in self._stubs: + self._stubs['streaming_read_feature_values'] = self.grpc_channel.unary_stream( + '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/StreamingReadFeatureValues', + request_serializer=featurestore_online_service.StreamingReadFeatureValuesRequest.serialize, + response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, + ) + return self._stubs['streaming_read_feature_values'] + + +__all__ = ( + 'FeaturestoreOnlineServingServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py new file mode 100644 index 0000000000..e3d630a7cc --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import FeaturestoreServiceClient +from .async_client import FeaturestoreServiceAsyncClient + +__all__ = ( + 'FeaturestoreServiceClient', + 'FeaturestoreServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py new file mode 100644 index 0000000000..e671bbfa1c --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -0,0 +1,2019 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.featurestore_service import pagers +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import feature_monitoring_stats +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore as gca_featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_monitoring +from google.cloud.aiplatform_v1beta1.types import featurestore_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import FeaturestoreServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import FeaturestoreServiceGrpcAsyncIOTransport +from .client import FeaturestoreServiceClient + + +class FeaturestoreServiceAsyncClient: + """The service that handles CRUD and List for resources for + Featurestore. + """ + + _client: FeaturestoreServiceClient + + DEFAULT_ENDPOINT = FeaturestoreServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = FeaturestoreServiceClient.DEFAULT_MTLS_ENDPOINT + + entity_type_path = staticmethod(FeaturestoreServiceClient.entity_type_path) + parse_entity_type_path = staticmethod(FeaturestoreServiceClient.parse_entity_type_path) + feature_path = staticmethod(FeaturestoreServiceClient.feature_path) + parse_feature_path = staticmethod(FeaturestoreServiceClient.parse_feature_path) + featurestore_path = staticmethod(FeaturestoreServiceClient.featurestore_path) + parse_featurestore_path = staticmethod(FeaturestoreServiceClient.parse_featurestore_path) + + common_billing_account_path = staticmethod(FeaturestoreServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(FeaturestoreServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(FeaturestoreServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(FeaturestoreServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(FeaturestoreServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(FeaturestoreServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(FeaturestoreServiceClient.common_project_path) + parse_common_project_path = staticmethod(FeaturestoreServiceClient.parse_common_project_path) + + common_location_path = staticmethod(FeaturestoreServiceClient.common_location_path) + parse_common_location_path = staticmethod(FeaturestoreServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreServiceAsyncClient: The constructed client. + """ + return FeaturestoreServiceClient.from_service_account_info.__func__(FeaturestoreServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreServiceAsyncClient: The constructed client. + """ + return FeaturestoreServiceClient.from_service_account_file.__func__(FeaturestoreServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> FeaturestoreServiceTransport: + """Return the transport used by the client instance. + + Returns: + FeaturestoreServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(FeaturestoreServiceClient).get_transport_class, type(FeaturestoreServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, FeaturestoreServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the featurestore service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.FeaturestoreServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = FeaturestoreServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_featurestore(self, + request: featurestore_service.CreateFeaturestoreRequest = None, + *, + parent: str = None, + featurestore: gca_featurestore.Featurestore = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a new Featurestore in a given project and + location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateFeaturestoreRequest`): + The request object. Request message for + ``FeaturestoreService.CreateFeaturestore``. + parent (:class:`str`): + Required. The resource name of the Location to create + Featurestores. Format: + ``projects/{project}/locations/{location}'`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + featurestore (:class:`google.cloud.aiplatform_v1beta1.types.Featurestore`): + Required. The Featurestore to create. + This corresponds to the ``featurestore`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.Featurestore` + Featurestore configuration information on how the + Featurestore is configured. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, featurestore]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.CreateFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if featurestore is not None: + request.featurestore = featurestore + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_featurestore, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_featurestore.Featurestore, + metadata_type=featurestore_service.CreateFeaturestoreOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_featurestore(self, + request: featurestore_service.GetFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore.Featurestore: + r"""Gets details of a single Featurestore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetFeaturestoreRequest`): + The request object. Request message for + ``FeaturestoreService.GetFeaturestore``. + name (:class:`str`): + Required. The name of the + Featurestore resource. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Featurestore: + Featurestore configuration + information on how the Featurestore is + configured. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.GetFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_featurestore, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_featurestores(self, + request: featurestore_service.ListFeaturestoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturestoresAsyncPager: + r"""Lists Featurestores in a given project and location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListFeaturestoresRequest`): + The request object. Request message for + ``FeaturestoreService.ListFeaturestores``. + parent (:class:`str`): + Required. The resource name of the Location to list + Featurestores. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.ListFeaturestoresAsyncPager: + Response message for + ``FeaturestoreService.ListFeaturestores``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.ListFeaturestoresRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_featurestores, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListFeaturestoresAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_featurestore(self, + request: featurestore_service.UpdateFeaturestoreRequest = None, + *, + featurestore: gca_featurestore.Featurestore = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates the parameters of a single Featurestore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateFeaturestoreRequest`): + The request object. Request message for + ``FeaturestoreService.UpdateFeaturestore``. + featurestore (:class:`google.cloud.aiplatform_v1beta1.types.Featurestore`): + Required. The Featurestore's ``name`` field is used to + identify the Featurestore to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``featurestore`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Field mask is used to specify the fields to be + overwritten in the Featurestore resource by the update. + The fields specified in the update_mask are relative to + the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then only the non-empty fields present in + the request will be overwritten. Set the update_mask to + ``*`` to override all fields. + + Updatable fields: + + - ``display_name`` + - ``labels`` + - ``online_serving_config.fixed_node_count`` + - ``online_serving_config.max_online_serving_size`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.Featurestore` + Featurestore configuration information on how the + Featurestore is configured. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([featurestore, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.UpdateFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if featurestore is not None: + request.featurestore = featurestore + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_featurestore, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('featurestore.name', request.featurestore.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_featurestore.Featurestore, + metadata_type=featurestore_service.UpdateFeaturestoreOperationMetadata, + ) + + # Done; return the response. + return response + + async def delete_featurestore(self, + request: featurestore_service.DeleteFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a single Featurestore. The Featurestore must not contain + any EntityTypes or ``force`` must be set to true for the request + to succeed. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteFeaturestoreRequest`): + The request object. Request message for + ``FeaturestoreService.DeleteFeaturestore``. + name (:class:`str`): + Required. The name of the Featurestore to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.DeleteFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_featurestore, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def create_entity_type(self, + request: featurestore_service.CreateEntityTypeRequest = None, + *, + parent: str = None, + entity_type: gca_entity_type.EntityType = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a new EntityType in a given Featurestore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateEntityTypeRequest`): + The request object. Request message for + ``FeaturestoreService.CreateEntityType``. + parent (:class:`str`): + Required. The resource name of the Featurestore to + create EntityTypes. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + entity_type (:class:`google.cloud.aiplatform_v1beta1.types.EntityType`): + The EntityType to create. + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.EntityType` An entity type is a type of object in a system that needs to be modeled and + have stored information about. For example, driver is + an entity type, and driver0 is an instance of an + entity type driver. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.CreateEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_entity_type, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_entity_type.EntityType, + metadata_type=featurestore_service.CreateEntityTypeOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_entity_type(self, + request: featurestore_service.GetEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> entity_type.EntityType: + r"""Gets details of a single EntityType. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetEntityTypeRequest`): + The request object. Request message for + ``FeaturestoreService.GetEntityType``. + name (:class:`str`): + Required. The name of the EntityType resource. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.EntityType: + An entity type is a type of object in + a system that needs to be modeled and + have stored information about. For + example, driver is an entity type, and + driver0 is an instance of an entity type + driver. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.GetEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_entity_type, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_entity_types(self, + request: featurestore_service.ListEntityTypesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEntityTypesAsyncPager: + r"""Lists EntityTypes in a given Featurestore. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListEntityTypesRequest`): + The request object. Request message for + ``FeaturestoreService.ListEntityTypes``. + parent (:class:`str`): + Required. The resource name of the Featurestore to list + EntityTypes. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.ListEntityTypesAsyncPager: + Response message for + ``FeaturestoreService.ListEntityTypes``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.ListEntityTypesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_entity_types, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListEntityTypesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_entity_type(self, + request: featurestore_service.UpdateEntityTypeRequest = None, + *, + entity_type: gca_entity_type.EntityType = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_entity_type.EntityType: + r"""Updates the parameters of a single EntityType. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateEntityTypeRequest`): + The request object. Request message for + [FeaturestoreService.UpdateEntityTypes][]. + entity_type (:class:`google.cloud.aiplatform_v1beta1.types.EntityType`): + Required. The EntityType's ``name`` field is used to + identify the EntityType to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Field mask is used to specify the fields to be + overwritten in the EntityType resource by the update. + The fields specified in the update_mask are relative to + the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then only the non-empty fields present in + the request will be overwritten. Set the update_mask to + ``*`` to override all fields. + + Updatable fields: + + - ``description`` + - ``labels`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.EntityType: + An entity type is a type of object in + a system that needs to be modeled and + have stored information about. For + example, driver is an entity type, and + driver0 is an instance of an entity type + driver. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.UpdateEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_entity_type, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type.name', request.entity_type.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_entity_type(self, + request: featurestore_service.DeleteEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a single EntityType. The EntityType must not have any + Features or ``force`` must be set to true for the request to + succeed. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteEntityTypeRequest`): + The request object. Request message for + [FeaturestoreService.DeleteEntityTypes][]. + name (:class:`str`): + Required. The name of the EntityType to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.DeleteEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_entity_type, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def create_feature(self, + request: featurestore_service.CreateFeatureRequest = None, + *, + parent: str = None, + feature: gca_feature.Feature = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a new Feature in a given EntityType. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateFeatureRequest`): + The request object. Request message for + ``FeaturestoreService.CreateFeature``. + parent (:class:`str`): + Required. The resource name of the EntityType to create + a Feature. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + feature (:class:`google.cloud.aiplatform_v1beta1.types.Feature`): + Required. The Feature to create. + This corresponds to the ``feature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Feature` Feature Metadata information that describes an attribute of an entity type. + For example, apple is an entity type, and color is a + feature that describes apple. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, feature]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.CreateFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if feature is not None: + request.feature = feature + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_feature, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_feature.Feature, + metadata_type=featurestore_service.CreateFeatureOperationMetadata, + ) + + # Done; return the response. + return response + + async def batch_create_features(self, + request: featurestore_service.BatchCreateFeaturesRequest = None, + *, + parent: str = None, + requests: Sequence[featurestore_service.CreateFeatureRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a batch of Features in a given EntityType. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.BatchCreateFeaturesRequest`): + The request object. Request message for + ``FeaturestoreService.BatchCreateFeatures``. + parent (:class:`str`): + Required. The resource name of the EntityType to create + the batch of Features under. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + requests (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.CreateFeatureRequest]`): + Required. The request message specifying the Features to + create. All Features must be created under the same + parent EntityType. The ``parent`` field in each child + request message can be omitted. If ``parent`` is set in + a child request, then the value must match the + ``parent`` value in this request message. + + This corresponds to the ``requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.BatchCreateFeaturesResponse` + Response message for + ``FeaturestoreService.BatchCreateFeatures``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, requests]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.BatchCreateFeaturesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + if requests: + request.requests.extend(requests) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_create_features, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + featurestore_service.BatchCreateFeaturesResponse, + metadata_type=featurestore_service.BatchCreateFeaturesOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_feature(self, + request: featurestore_service.GetFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> feature.Feature: + r"""Gets details of a single Feature. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetFeatureRequest`): + The request object. Request message for + ``FeaturestoreService.GetFeature``. + name (:class:`str`): + Required. The name of the Feature resource. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Feature: + Feature Metadata information that + describes an attribute of an entity + type. For example, apple is an entity + type, and color is a feature that + describes apple. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.GetFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_feature, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_features(self, + request: featurestore_service.ListFeaturesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturesAsyncPager: + r"""Lists Features in a given EntityType. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListFeaturesRequest`): + The request object. Request message for + ``FeaturestoreService.ListFeatures``. + parent (:class:`str`): + Required. The resource name of the Location to list + Features. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.ListFeaturesAsyncPager: + Response message for + ``FeaturestoreService.ListFeatures``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.ListFeaturesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_features, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListFeaturesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_feature(self, + request: featurestore_service.UpdateFeatureRequest = None, + *, + feature: gca_feature.Feature = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_feature.Feature: + r"""Updates the parameters of a single Feature. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateFeatureRequest`): + The request object. Request message for + ``FeaturestoreService.UpdateFeature``. + feature (:class:`google.cloud.aiplatform_v1beta1.types.Feature`): + Required. The Feature's ``name`` field is used to + identify the Feature to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + + This corresponds to the ``feature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Field mask is used to specify the fields to be + overwritten in the Features resource by the update. The + fields specified in the update_mask are relative to the + resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then only the non-empty fields present in + the request will be overwritten. Set the update_mask to + ``*`` to override all fields. + + Updatable fields: + + - ``description`` + - ``labels`` + - ``monitoring_config.snapshot_analysis.disabled`` + - ``monitoring_config.snapshot_analysis.monitoring_interval`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Feature: + Feature Metadata information that + describes an attribute of an entity + type. For example, apple is an entity + type, and color is a feature that + describes apple. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([feature, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.UpdateFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if feature is not None: + request.feature = feature + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_feature, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('feature.name', request.feature.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_feature(self, + request: featurestore_service.DeleteFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a single Feature. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteFeatureRequest`): + The request object. Request message for + ``FeaturestoreService.DeleteFeature``. + name (:class:`str`): + Required. The name of the Features to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.DeleteFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_feature, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def import_feature_values(self, + request: featurestore_service.ImportFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Imports Feature values into the Featurestore from a + source storage. + The progress of the import is tracked by the returned + operation. The imported features are guaranteed to be + visible to subsequent read operations after the + operation is marked as successfully done. + If an import operation fails, the Feature values + returned from reads and exports may be inconsistent. If + consistency is required, the caller must retry the same + import request again and wait till the new operation + returned is marked as successfully done. + There are also scenarios where the caller can cause + inconsistency. + - Source data for import contains multiple distinct + Feature values for the same entity ID and timestamp. + - Source is modified during an import. This includes + adding, updating, or removing source data and/or + metadata. Examples of updating metadata include but are + not limited to changing storage location, storage class, + or retention policy. + - Online serving cluster is under-provisioned. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ImportFeatureValuesRequest`): + The request object. Request message for + ``FeaturestoreService.ImportFeatureValues``. + entity_type (:class:`str`): + Required. The resource name of the EntityType grouping + the Features for which values are being imported. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}`` + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.ImportFeatureValuesResponse` + Response message for + ``FeaturestoreService.ImportFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.ImportFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.import_feature_values, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type', request.entity_type), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + featurestore_service.ImportFeatureValuesResponse, + metadata_type=featurestore_service.ImportFeatureValuesOperationMetadata, + ) + + # Done; return the response. + return response + + async def batch_read_feature_values(self, + request: featurestore_service.BatchReadFeatureValuesRequest = None, + *, + featurestore: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Batch reads Feature values from a Featurestore. + This API enables batch reading Feature values, where + each read instance in the batch may read Feature values + of entities from one or more EntityTypes. Point-in-time + correctness is guaranteed for Feature values of each + read instance as of each instance's read timestamp. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesRequest`): + The request object. Request message for + ``FeaturestoreService.BatchReadFeatureValues``. + featurestore (:class:`str`): + Required. The resource name of the Featurestore from + which to query Feature values. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``featurestore`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesResponse` + Response message for + ``FeaturestoreService.BatchReadFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([featurestore]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.BatchReadFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if featurestore is not None: + request.featurestore = featurestore + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_read_feature_values, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('featurestore', request.featurestore), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + featurestore_service.BatchReadFeatureValuesResponse, + metadata_type=featurestore_service.BatchReadFeatureValuesOperationMetadata, + ) + + # Done; return the response. + return response + + async def search_features(self, + request: featurestore_service.SearchFeaturesRequest = None, + *, + location: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchFeaturesAsyncPager: + r"""Searches Features matching a query in a given + project. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.SearchFeaturesRequest`): + The request object. Request message for + ``FeaturestoreService.SearchFeatures``. + location (:class:`str`): + Required. The resource name of the Location to search + Features. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``location`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.SearchFeaturesAsyncPager: + Response message for + ``FeaturestoreService.SearchFeatures``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([location]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = featurestore_service.SearchFeaturesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if location is not None: + request.location = location + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.search_features, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('location', request.location), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.SearchFeaturesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'FeaturestoreServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py new file mode 100644 index 0000000000..c566a9b24e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -0,0 +1,2227 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.featurestore_service import pagers +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import feature_monitoring_stats +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore as gca_featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_monitoring +from google.cloud.aiplatform_v1beta1.types import featurestore_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import FeaturestoreServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import FeaturestoreServiceGrpcTransport +from .transports.grpc_asyncio import FeaturestoreServiceGrpcAsyncIOTransport + + +class FeaturestoreServiceClientMeta(type): + """Metaclass for the FeaturestoreService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreServiceTransport]] + _transport_registry['grpc'] = FeaturestoreServiceGrpcTransport + _transport_registry['grpc_asyncio'] = FeaturestoreServiceGrpcAsyncIOTransport + + def get_transport_class(cls, + label: str = None, + ) -> Type[FeaturestoreServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class FeaturestoreServiceClient(metaclass=FeaturestoreServiceClientMeta): + """The service that handles CRUD and List for resources for + Featurestore. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FeaturestoreServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> FeaturestoreServiceTransport: + """Return the transport used by the client instance. + + Returns: + FeaturestoreServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def entity_type_path(project: str,location: str,featurestore: str,entity_type: str,) -> str: + """Return a fully-qualified entity_type string.""" + return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) + + @staticmethod + def parse_entity_type_path(path: str) -> Dict[str,str]: + """Parse a entity_type path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def feature_path(project: str,location: str,featurestore: str,entity_type: str,feature: str,) -> str: + """Return a fully-qualified feature string.""" + return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, feature=feature, ) + + @staticmethod + def parse_feature_path(path: str) -> Dict[str,str]: + """Parse a feature path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)/features/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def featurestore_path(project: str,location: str,featurestore: str,) -> str: + """Return a fully-qualified featurestore string.""" + return "projects/{project}/locations/{location}/featurestores/{featurestore}".format(project=project, location=location, featurestore=featurestore, ) + + @staticmethod + def parse_featurestore_path(path: str) -> Dict[str,str]: + """Parse a featurestore path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, FeaturestoreServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the featurestore service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, FeaturestoreServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, FeaturestoreServiceTransport): + # transport is a FeaturestoreServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_featurestore(self, + request: featurestore_service.CreateFeaturestoreRequest = None, + *, + parent: str = None, + featurestore: gca_featurestore.Featurestore = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates a new Featurestore in a given project and + location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateFeaturestoreRequest): + The request object. Request message for + ``FeaturestoreService.CreateFeaturestore``. + parent (str): + Required. The resource name of the Location to create + Featurestores. Format: + ``projects/{project}/locations/{location}'`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + featurestore (google.cloud.aiplatform_v1beta1.types.Featurestore): + Required. The Featurestore to create. + This corresponds to the ``featurestore`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.Featurestore` + Featurestore configuration information on how the + Featurestore is configured. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, featurestore]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.CreateFeaturestoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.CreateFeaturestoreRequest): + request = featurestore_service.CreateFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if featurestore is not None: + request.featurestore = featurestore + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_featurestore] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_featurestore.Featurestore, + metadata_type=featurestore_service.CreateFeaturestoreOperationMetadata, + ) + + # Done; return the response. + return response + + def get_featurestore(self, + request: featurestore_service.GetFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore.Featurestore: + r"""Gets details of a single Featurestore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetFeaturestoreRequest): + The request object. Request message for + ``FeaturestoreService.GetFeaturestore``. + name (str): + Required. The name of the + Featurestore resource. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Featurestore: + Featurestore configuration + information on how the Featurestore is + configured. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.GetFeaturestoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.GetFeaturestoreRequest): + request = featurestore_service.GetFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_featurestore] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_featurestores(self, + request: featurestore_service.ListFeaturestoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturestoresPager: + r"""Lists Featurestores in a given project and location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresRequest): + The request object. Request message for + ``FeaturestoreService.ListFeaturestores``. + parent (str): + Required. The resource name of the Location to list + Featurestores. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.ListFeaturestoresPager: + Response message for + ``FeaturestoreService.ListFeaturestores``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.ListFeaturestoresRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.ListFeaturestoresRequest): + request = featurestore_service.ListFeaturestoresRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_featurestores] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListFeaturestoresPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_featurestore(self, + request: featurestore_service.UpdateFeaturestoreRequest = None, + *, + featurestore: gca_featurestore.Featurestore = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Updates the parameters of a single Featurestore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateFeaturestoreRequest): + The request object. Request message for + ``FeaturestoreService.UpdateFeaturestore``. + featurestore (google.cloud.aiplatform_v1beta1.types.Featurestore): + Required. The Featurestore's ``name`` field is used to + identify the Featurestore to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``featurestore`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Field mask is used to specify the fields to be + overwritten in the Featurestore resource by the update. + The fields specified in the update_mask are relative to + the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then only the non-empty fields present in + the request will be overwritten. Set the update_mask to + ``*`` to override all fields. + + Updatable fields: + + - ``display_name`` + - ``labels`` + - ``online_serving_config.fixed_node_count`` + - ``online_serving_config.max_online_serving_size`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.Featurestore` + Featurestore configuration information on how the + Featurestore is configured. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([featurestore, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.UpdateFeaturestoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.UpdateFeaturestoreRequest): + request = featurestore_service.UpdateFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if featurestore is not None: + request.featurestore = featurestore + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_featurestore] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('featurestore.name', request.featurestore.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_featurestore.Featurestore, + metadata_type=featurestore_service.UpdateFeaturestoreOperationMetadata, + ) + + # Done; return the response. + return response + + def delete_featurestore(self, + request: featurestore_service.DeleteFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a single Featurestore. The Featurestore must not contain + any EntityTypes or ``force`` must be set to true for the request + to succeed. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteFeaturestoreRequest): + The request object. Request message for + ``FeaturestoreService.DeleteFeaturestore``. + name (str): + Required. The name of the Featurestore to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.DeleteFeaturestoreRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.DeleteFeaturestoreRequest): + request = featurestore_service.DeleteFeaturestoreRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_featurestore] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def create_entity_type(self, + request: featurestore_service.CreateEntityTypeRequest = None, + *, + parent: str = None, + entity_type: gca_entity_type.EntityType = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates a new EntityType in a given Featurestore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateEntityTypeRequest): + The request object. Request message for + ``FeaturestoreService.CreateEntityType``. + parent (str): + Required. The resource name of the Featurestore to + create EntityTypes. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + entity_type (google.cloud.aiplatform_v1beta1.types.EntityType): + The EntityType to create. + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.EntityType` An entity type is a type of object in a system that needs to be modeled and + have stored information about. For example, driver is + an entity type, and driver0 is an instance of an + entity type driver. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.CreateEntityTypeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.CreateEntityTypeRequest): + request = featurestore_service.CreateEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_entity_type] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_entity_type.EntityType, + metadata_type=featurestore_service.CreateEntityTypeOperationMetadata, + ) + + # Done; return the response. + return response + + def get_entity_type(self, + request: featurestore_service.GetEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> entity_type.EntityType: + r"""Gets details of a single EntityType. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetEntityTypeRequest): + The request object. Request message for + ``FeaturestoreService.GetEntityType``. + name (str): + Required. The name of the EntityType resource. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.EntityType: + An entity type is a type of object in + a system that needs to be modeled and + have stored information about. For + example, driver is an entity type, and + driver0 is an instance of an entity type + driver. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.GetEntityTypeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.GetEntityTypeRequest): + request = featurestore_service.GetEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_entity_type] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_entity_types(self, + request: featurestore_service.ListEntityTypesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEntityTypesPager: + r"""Lists EntityTypes in a given Featurestore. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListEntityTypesRequest): + The request object. Request message for + ``FeaturestoreService.ListEntityTypes``. + parent (str): + Required. The resource name of the Featurestore to list + EntityTypes. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.ListEntityTypesPager: + Response message for + ``FeaturestoreService.ListEntityTypes``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.ListEntityTypesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.ListEntityTypesRequest): + request = featurestore_service.ListEntityTypesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_entity_types] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListEntityTypesPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_entity_type(self, + request: featurestore_service.UpdateEntityTypeRequest = None, + *, + entity_type: gca_entity_type.EntityType = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_entity_type.EntityType: + r"""Updates the parameters of a single EntityType. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateEntityTypeRequest): + The request object. Request message for + [FeaturestoreService.UpdateEntityTypes][]. + entity_type (google.cloud.aiplatform_v1beta1.types.EntityType): + Required. The EntityType's ``name`` field is used to + identify the EntityType to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Field mask is used to specify the fields to be + overwritten in the EntityType resource by the update. + The fields specified in the update_mask are relative to + the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then only the non-empty fields present in + the request will be overwritten. Set the update_mask to + ``*`` to override all fields. + + Updatable fields: + + - ``description`` + - ``labels`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.EntityType: + An entity type is a type of object in + a system that needs to be modeled and + have stored information about. For + example, driver is an entity type, and + driver0 is an instance of an entity type + driver. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.UpdateEntityTypeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.UpdateEntityTypeRequest): + request = featurestore_service.UpdateEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_entity_type] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type.name', request.entity_type.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_entity_type(self, + request: featurestore_service.DeleteEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a single EntityType. The EntityType must not have any + Features or ``force`` must be set to true for the request to + succeed. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteEntityTypeRequest): + The request object. Request message for + [FeaturestoreService.DeleteEntityTypes][]. + name (str): + Required. The name of the EntityType to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.DeleteEntityTypeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.DeleteEntityTypeRequest): + request = featurestore_service.DeleteEntityTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_entity_type] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def create_feature(self, + request: featurestore_service.CreateFeatureRequest = None, + *, + parent: str = None, + feature: gca_feature.Feature = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates a new Feature in a given EntityType. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateFeatureRequest): + The request object. Request message for + ``FeaturestoreService.CreateFeature``. + parent (str): + Required. The resource name of the EntityType to create + a Feature. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + feature (google.cloud.aiplatform_v1beta1.types.Feature): + Required. The Feature to create. + This corresponds to the ``feature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Feature` Feature Metadata information that describes an attribute of an entity type. + For example, apple is an entity type, and color is a + feature that describes apple. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, feature]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.CreateFeatureRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.CreateFeatureRequest): + request = featurestore_service.CreateFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if feature is not None: + request.feature = feature + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_feature] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_feature.Feature, + metadata_type=featurestore_service.CreateFeatureOperationMetadata, + ) + + # Done; return the response. + return response + + def batch_create_features(self, + request: featurestore_service.BatchCreateFeaturesRequest = None, + *, + parent: str = None, + requests: Sequence[featurestore_service.CreateFeatureRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates a batch of Features in a given EntityType. + + Args: + request (google.cloud.aiplatform_v1beta1.types.BatchCreateFeaturesRequest): + The request object. Request message for + ``FeaturestoreService.BatchCreateFeatures``. + parent (str): + Required. The resource name of the EntityType to create + the batch of Features under. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + requests (Sequence[google.cloud.aiplatform_v1beta1.types.CreateFeatureRequest]): + Required. The request message specifying the Features to + create. All Features must be created under the same + parent EntityType. The ``parent`` field in each child + request message can be omitted. If ``parent`` is set in + a child request, then the value must match the + ``parent`` value in this request message. + + This corresponds to the ``requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.BatchCreateFeaturesResponse` + Response message for + ``FeaturestoreService.BatchCreateFeatures``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, requests]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.BatchCreateFeaturesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.BatchCreateFeaturesRequest): + request = featurestore_service.BatchCreateFeaturesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if requests is not None: + request.requests = requests + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_create_features] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + featurestore_service.BatchCreateFeaturesResponse, + metadata_type=featurestore_service.BatchCreateFeaturesOperationMetadata, + ) + + # Done; return the response. + return response + + def get_feature(self, + request: featurestore_service.GetFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> feature.Feature: + r"""Gets details of a single Feature. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetFeatureRequest): + The request object. Request message for + ``FeaturestoreService.GetFeature``. + name (str): + Required. The name of the Feature resource. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Feature: + Feature Metadata information that + describes an attribute of an entity + type. For example, apple is an entity + type, and color is a feature that + describes apple. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.GetFeatureRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.GetFeatureRequest): + request = featurestore_service.GetFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_feature] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_features(self, + request: featurestore_service.ListFeaturesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturesPager: + r"""Lists Features in a given EntityType. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListFeaturesRequest): + The request object. Request message for + ``FeaturestoreService.ListFeatures``. + parent (str): + Required. The resource name of the Location to list + Features. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.ListFeaturesPager: + Response message for + ``FeaturestoreService.ListFeatures``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.ListFeaturesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.ListFeaturesRequest): + request = featurestore_service.ListFeaturesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_features] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListFeaturesPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_feature(self, + request: featurestore_service.UpdateFeatureRequest = None, + *, + feature: gca_feature.Feature = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_feature.Feature: + r"""Updates the parameters of a single Feature. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateFeatureRequest): + The request object. Request message for + ``FeaturestoreService.UpdateFeature``. + feature (google.cloud.aiplatform_v1beta1.types.Feature): + Required. The Feature's ``name`` field is used to + identify the Feature to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + + This corresponds to the ``feature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Field mask is used to specify the fields to be + overwritten in the Features resource by the update. The + fields specified in the update_mask are relative to the + resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then only the non-empty fields present in + the request will be overwritten. Set the update_mask to + ``*`` to override all fields. + + Updatable fields: + + - ``description`` + - ``labels`` + - ``monitoring_config.snapshot_analysis.disabled`` + - ``monitoring_config.snapshot_analysis.monitoring_interval`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Feature: + Feature Metadata information that + describes an attribute of an entity + type. For example, apple is an entity + type, and color is a feature that + describes apple. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([feature, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.UpdateFeatureRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.UpdateFeatureRequest): + request = featurestore_service.UpdateFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if feature is not None: + request.feature = feature + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_feature] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('feature.name', request.feature.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_feature(self, + request: featurestore_service.DeleteFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a single Feature. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteFeatureRequest): + The request object. Request message for + ``FeaturestoreService.DeleteFeature``. + name (str): + Required. The name of the Features to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.DeleteFeatureRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.DeleteFeatureRequest): + request = featurestore_service.DeleteFeatureRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_feature] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def import_feature_values(self, + request: featurestore_service.ImportFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Imports Feature values into the Featurestore from a + source storage. + The progress of the import is tracked by the returned + operation. The imported features are guaranteed to be + visible to subsequent read operations after the + operation is marked as successfully done. + If an import operation fails, the Feature values + returned from reads and exports may be inconsistent. If + consistency is required, the caller must retry the same + import request again and wait till the new operation + returned is marked as successfully done. + There are also scenarios where the caller can cause + inconsistency. + - Source data for import contains multiple distinct + Feature values for the same entity ID and timestamp. + - Source is modified during an import. This includes + adding, updating, or removing source data and/or + metadata. Examples of updating metadata include but are + not limited to changing storage location, storage class, + or retention policy. + - Online serving cluster is under-provisioned. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ImportFeatureValuesRequest): + The request object. Request message for + ``FeaturestoreService.ImportFeatureValues``. + entity_type (str): + Required. The resource name of the EntityType grouping + the Features for which values are being imported. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}`` + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.ImportFeatureValuesResponse` + Response message for + ``FeaturestoreService.ImportFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.ImportFeatureValuesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.ImportFeatureValuesRequest): + request = featurestore_service.ImportFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.import_feature_values] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('entity_type', request.entity_type), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + featurestore_service.ImportFeatureValuesResponse, + metadata_type=featurestore_service.ImportFeatureValuesOperationMetadata, + ) + + # Done; return the response. + return response + + def batch_read_feature_values(self, + request: featurestore_service.BatchReadFeatureValuesRequest = None, + *, + featurestore: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Batch reads Feature values from a Featurestore. + This API enables batch reading Feature values, where + each read instance in the batch may read Feature values + of entities from one or more EntityTypes. Point-in-time + correctness is guaranteed for Feature values of each + read instance as of each instance's read timestamp. + + Args: + request (google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesRequest): + The request object. Request message for + ``FeaturestoreService.BatchReadFeatureValues``. + featurestore (str): + Required. The resource name of the Featurestore from + which to query Feature values. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + + This corresponds to the ``featurestore`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesResponse` + Response message for + ``FeaturestoreService.BatchReadFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([featurestore]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.BatchReadFeatureValuesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.BatchReadFeatureValuesRequest): + request = featurestore_service.BatchReadFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if featurestore is not None: + request.featurestore = featurestore + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_read_feature_values] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('featurestore', request.featurestore), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + featurestore_service.BatchReadFeatureValuesResponse, + metadata_type=featurestore_service.BatchReadFeatureValuesOperationMetadata, + ) + + # Done; return the response. + return response + + def search_features(self, + request: featurestore_service.SearchFeaturesRequest = None, + *, + location: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchFeaturesPager: + r"""Searches Features matching a query in a given + project. + + Args: + request (google.cloud.aiplatform_v1beta1.types.SearchFeaturesRequest): + The request object. Request message for + ``FeaturestoreService.SearchFeatures``. + location (str): + Required. The resource name of the Location to search + Features. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``location`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.featurestore_service.pagers.SearchFeaturesPager: + Response message for + ``FeaturestoreService.SearchFeatures``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([location]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.SearchFeaturesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.SearchFeaturesRequest): + request = featurestore_service.SearchFeaturesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if location is not None: + request.location = location + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.search_features] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('location', request.location), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.SearchFeaturesPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'FeaturestoreServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py new file mode 100644 index 0000000000..7baa8e920c --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py @@ -0,0 +1,511 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional + +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_service + + +class ListFeaturestoresPager: + """A pager for iterating through ``list_featurestores`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse` object, and + provides an ``__iter__`` method to iterate through its + ``featurestores`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListFeaturestores`` requests and continue to iterate + through the ``featurestores`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., featurestore_service.ListFeaturestoresResponse], + request: featurestore_service.ListFeaturestoresRequest, + response: featurestore_service.ListFeaturestoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.ListFeaturestoresRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[featurestore_service.ListFeaturestoresResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[featurestore.Featurestore]: + for page in self.pages: + yield from page.featurestores + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListFeaturestoresAsyncPager: + """A pager for iterating through ``list_featurestores`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``featurestores`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListFeaturestores`` requests and continue to iterate + through the ``featurestores`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[featurestore_service.ListFeaturestoresResponse]], + request: featurestore_service.ListFeaturestoresRequest, + response: featurestore_service.ListFeaturestoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.ListFeaturestoresRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[featurestore_service.ListFeaturestoresResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[featurestore.Featurestore]: + async def async_generator(): + async for page in self.pages: + for response in page.featurestores: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListEntityTypesPager: + """A pager for iterating through ``list_entity_types`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``entity_types`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListEntityTypes`` requests and continue to iterate + through the ``entity_types`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., featurestore_service.ListEntityTypesResponse], + request: featurestore_service.ListEntityTypesRequest, + response: featurestore_service.ListEntityTypesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListEntityTypesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.ListEntityTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[featurestore_service.ListEntityTypesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[entity_type.EntityType]: + for page in self.pages: + yield from page.entity_types + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListEntityTypesAsyncPager: + """A pager for iterating through ``list_entity_types`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``entity_types`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListEntityTypes`` requests and continue to iterate + through the ``entity_types`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[featurestore_service.ListEntityTypesResponse]], + request: featurestore_service.ListEntityTypesRequest, + response: featurestore_service.ListEntityTypesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListEntityTypesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.ListEntityTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[featurestore_service.ListEntityTypesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[entity_type.EntityType]: + async def async_generator(): + async for page in self.pages: + for response in page.entity_types: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListFeaturesPager: + """A pager for iterating through ``list_features`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``features`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListFeatures`` requests and continue to iterate + through the ``features`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., featurestore_service.ListFeaturesResponse], + request: featurestore_service.ListFeaturesRequest, + response: featurestore_service.ListFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListFeaturesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.ListFeaturesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[featurestore_service.ListFeaturesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[feature.Feature]: + for page in self.pages: + yield from page.features + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListFeaturesAsyncPager: + """A pager for iterating through ``list_features`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``features`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListFeatures`` requests and continue to iterate + through the ``features`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[featurestore_service.ListFeaturesResponse]], + request: featurestore_service.ListFeaturesRequest, + response: featurestore_service.ListFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListFeaturesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.ListFeaturesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[featurestore_service.ListFeaturesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[feature.Feature]: + async def async_generator(): + async for page in self.pages: + for response in page.features: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class SearchFeaturesPager: + """A pager for iterating through ``search_features`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``features`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``SearchFeatures`` requests and continue to iterate + through the ``features`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., featurestore_service.SearchFeaturesResponse], + request: featurestore_service.SearchFeaturesRequest, + response: featurestore_service.SearchFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.SearchFeaturesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.SearchFeaturesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[featurestore_service.SearchFeaturesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[feature.Feature]: + for page in self.pages: + yield from page.features + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class SearchFeaturesAsyncPager: + """A pager for iterating through ``search_features`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``features`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``SearchFeatures`` requests and continue to iterate + through the ``features`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[featurestore_service.SearchFeaturesResponse]], + request: featurestore_service.SearchFeaturesRequest, + response: featurestore_service.SearchFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.SearchFeaturesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = featurestore_service.SearchFeaturesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[featurestore_service.SearchFeaturesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[feature.Feature]: + async def async_generator(): + async for page in self.pages: + for response in page.features: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py new file mode 100644 index 0000000000..3fdc8aa3df --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import FeaturestoreServiceTransport +from .grpc import FeaturestoreServiceGrpcTransport +from .grpc_asyncio import FeaturestoreServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreServiceTransport]] +_transport_registry['grpc'] = FeaturestoreServiceGrpcTransport +_transport_registry['grpc_asyncio'] = FeaturestoreServiceGrpcAsyncIOTransport + +__all__ = ( + 'FeaturestoreServiceTransport', + 'FeaturestoreServiceGrpcTransport', + 'FeaturestoreServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py new file mode 100644 index 0000000000..4adf29e11b --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py @@ -0,0 +1,391 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_service +from google.longrunning import operations_pb2 as operations # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class FeaturestoreServiceTransport(abc.ABC): + """Abstract transport class for FeaturestoreService.""" + + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) + + def __init__( + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_featurestore: gapic_v1.method.wrap_method( + self.create_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.get_featurestore: gapic_v1.method.wrap_method( + self.get_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.list_featurestores: gapic_v1.method.wrap_method( + self.list_featurestores, + default_timeout=None, + client_info=client_info, + ), + self.update_featurestore: gapic_v1.method.wrap_method( + self.update_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.delete_featurestore: gapic_v1.method.wrap_method( + self.delete_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.create_entity_type: gapic_v1.method.wrap_method( + self.create_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.get_entity_type: gapic_v1.method.wrap_method( + self.get_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.list_entity_types: gapic_v1.method.wrap_method( + self.list_entity_types, + default_timeout=None, + client_info=client_info, + ), + self.update_entity_type: gapic_v1.method.wrap_method( + self.update_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.delete_entity_type: gapic_v1.method.wrap_method( + self.delete_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.create_feature: gapic_v1.method.wrap_method( + self.create_feature, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_features: gapic_v1.method.wrap_method( + self.batch_create_features, + default_timeout=None, + client_info=client_info, + ), + self.get_feature: gapic_v1.method.wrap_method( + self.get_feature, + default_timeout=None, + client_info=client_info, + ), + self.list_features: gapic_v1.method.wrap_method( + self.list_features, + default_timeout=None, + client_info=client_info, + ), + self.update_feature: gapic_v1.method.wrap_method( + self.update_feature, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature: gapic_v1.method.wrap_method( + self.delete_feature, + default_timeout=None, + client_info=client_info, + ), + self.import_feature_values: gapic_v1.method.wrap_method( + self.import_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.batch_read_feature_values: gapic_v1.method.wrap_method( + self.batch_read_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.search_features: gapic_v1.method.wrap_method( + self.search_features, + default_timeout=None, + client_info=client_info, + ), + + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_featurestore(self) -> typing.Callable[ + [featurestore_service.CreateFeaturestoreRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def get_featurestore(self) -> typing.Callable[ + [featurestore_service.GetFeaturestoreRequest], + typing.Union[ + featurestore.Featurestore, + typing.Awaitable[featurestore.Featurestore] + ]]: + raise NotImplementedError() + + @property + def list_featurestores(self) -> typing.Callable[ + [featurestore_service.ListFeaturestoresRequest], + typing.Union[ + featurestore_service.ListFeaturestoresResponse, + typing.Awaitable[featurestore_service.ListFeaturestoresResponse] + ]]: + raise NotImplementedError() + + @property + def update_featurestore(self) -> typing.Callable[ + [featurestore_service.UpdateFeaturestoreRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def delete_featurestore(self) -> typing.Callable[ + [featurestore_service.DeleteFeaturestoreRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def create_entity_type(self) -> typing.Callable[ + [featurestore_service.CreateEntityTypeRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def get_entity_type(self) -> typing.Callable[ + [featurestore_service.GetEntityTypeRequest], + typing.Union[ + entity_type.EntityType, + typing.Awaitable[entity_type.EntityType] + ]]: + raise NotImplementedError() + + @property + def list_entity_types(self) -> typing.Callable[ + [featurestore_service.ListEntityTypesRequest], + typing.Union[ + featurestore_service.ListEntityTypesResponse, + typing.Awaitable[featurestore_service.ListEntityTypesResponse] + ]]: + raise NotImplementedError() + + @property + def update_entity_type(self) -> typing.Callable[ + [featurestore_service.UpdateEntityTypeRequest], + typing.Union[ + gca_entity_type.EntityType, + typing.Awaitable[gca_entity_type.EntityType] + ]]: + raise NotImplementedError() + + @property + def delete_entity_type(self) -> typing.Callable[ + [featurestore_service.DeleteEntityTypeRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def create_feature(self) -> typing.Callable[ + [featurestore_service.CreateFeatureRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def batch_create_features(self) -> typing.Callable[ + [featurestore_service.BatchCreateFeaturesRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def get_feature(self) -> typing.Callable[ + [featurestore_service.GetFeatureRequest], + typing.Union[ + feature.Feature, + typing.Awaitable[feature.Feature] + ]]: + raise NotImplementedError() + + @property + def list_features(self) -> typing.Callable[ + [featurestore_service.ListFeaturesRequest], + typing.Union[ + featurestore_service.ListFeaturesResponse, + typing.Awaitable[featurestore_service.ListFeaturesResponse] + ]]: + raise NotImplementedError() + + @property + def update_feature(self) -> typing.Callable[ + [featurestore_service.UpdateFeatureRequest], + typing.Union[ + gca_feature.Feature, + typing.Awaitable[gca_feature.Feature] + ]]: + raise NotImplementedError() + + @property + def delete_feature(self) -> typing.Callable[ + [featurestore_service.DeleteFeatureRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def import_feature_values(self) -> typing.Callable[ + [featurestore_service.ImportFeatureValuesRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def batch_read_feature_values(self) -> typing.Callable[ + [featurestore_service.BatchReadFeatureValuesRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def search_features(self) -> typing.Callable[ + [featurestore_service.SearchFeaturesRequest], + typing.Union[ + featurestore_service.SearchFeaturesResponse, + typing.Awaitable[featurestore_service.SearchFeaturesResponse] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'FeaturestoreServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py new file mode 100644 index 0000000000..48fb007e78 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py @@ -0,0 +1,772 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import FeaturestoreServiceTransport, DEFAULT_CLIENT_INFO + + +class FeaturestoreServiceGrpcTransport(FeaturestoreServiceTransport): + """gRPC backend transport for FeaturestoreService. + + The service that handles CRUD and List for resources for + Featurestore. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_featurestore(self) -> Callable[ + [featurestore_service.CreateFeaturestoreRequest], + operations.Operation]: + r"""Return a callable for the create featurestore method over gRPC. + + Creates a new Featurestore in a given project and + location. + + Returns: + Callable[[~.CreateFeaturestoreRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_featurestore' not in self._stubs: + self._stubs['create_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeaturestore', + request_serializer=featurestore_service.CreateFeaturestoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_featurestore'] + + @property + def get_featurestore(self) -> Callable[ + [featurestore_service.GetFeaturestoreRequest], + featurestore.Featurestore]: + r"""Return a callable for the get featurestore method over gRPC. + + Gets details of a single Featurestore. + + Returns: + Callable[[~.GetFeaturestoreRequest], + ~.Featurestore]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_featurestore' not in self._stubs: + self._stubs['get_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeaturestore', + request_serializer=featurestore_service.GetFeaturestoreRequest.serialize, + response_deserializer=featurestore.Featurestore.deserialize, + ) + return self._stubs['get_featurestore'] + + @property + def list_featurestores(self) -> Callable[ + [featurestore_service.ListFeaturestoresRequest], + featurestore_service.ListFeaturestoresResponse]: + r"""Return a callable for the list featurestores method over gRPC. + + Lists Featurestores in a given project and location. + + Returns: + Callable[[~.ListFeaturestoresRequest], + ~.ListFeaturestoresResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_featurestores' not in self._stubs: + self._stubs['list_featurestores'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeaturestores', + request_serializer=featurestore_service.ListFeaturestoresRequest.serialize, + response_deserializer=featurestore_service.ListFeaturestoresResponse.deserialize, + ) + return self._stubs['list_featurestores'] + + @property + def update_featurestore(self) -> Callable[ + [featurestore_service.UpdateFeaturestoreRequest], + operations.Operation]: + r"""Return a callable for the update featurestore method over gRPC. + + Updates the parameters of a single Featurestore. + + Returns: + Callable[[~.UpdateFeaturestoreRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_featurestore' not in self._stubs: + self._stubs['update_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeaturestore', + request_serializer=featurestore_service.UpdateFeaturestoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_featurestore'] + + @property + def delete_featurestore(self) -> Callable[ + [featurestore_service.DeleteFeaturestoreRequest], + operations.Operation]: + r"""Return a callable for the delete featurestore method over gRPC. + + Deletes a single Featurestore. The Featurestore must not contain + any EntityTypes or ``force`` must be set to true for the request + to succeed. + + Returns: + Callable[[~.DeleteFeaturestoreRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_featurestore' not in self._stubs: + self._stubs['delete_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeaturestore', + request_serializer=featurestore_service.DeleteFeaturestoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_featurestore'] + + @property + def create_entity_type(self) -> Callable[ + [featurestore_service.CreateEntityTypeRequest], + operations.Operation]: + r"""Return a callable for the create entity type method over gRPC. + + Creates a new EntityType in a given Featurestore. + + Returns: + Callable[[~.CreateEntityTypeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_entity_type' not in self._stubs: + self._stubs['create_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateEntityType', + request_serializer=featurestore_service.CreateEntityTypeRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_entity_type'] + + @property + def get_entity_type(self) -> Callable[ + [featurestore_service.GetEntityTypeRequest], + entity_type.EntityType]: + r"""Return a callable for the get entity type method over gRPC. + + Gets details of a single EntityType. + + Returns: + Callable[[~.GetEntityTypeRequest], + ~.EntityType]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_entity_type' not in self._stubs: + self._stubs['get_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetEntityType', + request_serializer=featurestore_service.GetEntityTypeRequest.serialize, + response_deserializer=entity_type.EntityType.deserialize, + ) + return self._stubs['get_entity_type'] + + @property + def list_entity_types(self) -> Callable[ + [featurestore_service.ListEntityTypesRequest], + featurestore_service.ListEntityTypesResponse]: + r"""Return a callable for the list entity types method over gRPC. + + Lists EntityTypes in a given Featurestore. + + Returns: + Callable[[~.ListEntityTypesRequest], + ~.ListEntityTypesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_entity_types' not in self._stubs: + self._stubs['list_entity_types'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListEntityTypes', + request_serializer=featurestore_service.ListEntityTypesRequest.serialize, + response_deserializer=featurestore_service.ListEntityTypesResponse.deserialize, + ) + return self._stubs['list_entity_types'] + + @property + def update_entity_type(self) -> Callable[ + [featurestore_service.UpdateEntityTypeRequest], + gca_entity_type.EntityType]: + r"""Return a callable for the update entity type method over gRPC. + + Updates the parameters of a single EntityType. + + Returns: + Callable[[~.UpdateEntityTypeRequest], + ~.EntityType]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_entity_type' not in self._stubs: + self._stubs['update_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateEntityType', + request_serializer=featurestore_service.UpdateEntityTypeRequest.serialize, + response_deserializer=gca_entity_type.EntityType.deserialize, + ) + return self._stubs['update_entity_type'] + + @property + def delete_entity_type(self) -> Callable[ + [featurestore_service.DeleteEntityTypeRequest], + operations.Operation]: + r"""Return a callable for the delete entity type method over gRPC. + + Deletes a single EntityType. The EntityType must not have any + Features or ``force`` must be set to true for the request to + succeed. + + Returns: + Callable[[~.DeleteEntityTypeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_entity_type' not in self._stubs: + self._stubs['delete_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteEntityType', + request_serializer=featurestore_service.DeleteEntityTypeRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_entity_type'] + + @property + def create_feature(self) -> Callable[ + [featurestore_service.CreateFeatureRequest], + operations.Operation]: + r"""Return a callable for the create feature method over gRPC. + + Creates a new Feature in a given EntityType. + + Returns: + Callable[[~.CreateFeatureRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_feature' not in self._stubs: + self._stubs['create_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeature', + request_serializer=featurestore_service.CreateFeatureRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_feature'] + + @property + def batch_create_features(self) -> Callable[ + [featurestore_service.BatchCreateFeaturesRequest], + operations.Operation]: + r"""Return a callable for the batch create features method over gRPC. + + Creates a batch of Features in a given EntityType. + + Returns: + Callable[[~.BatchCreateFeaturesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_create_features' not in self._stubs: + self._stubs['batch_create_features'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchCreateFeatures', + request_serializer=featurestore_service.BatchCreateFeaturesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['batch_create_features'] + + @property + def get_feature(self) -> Callable[ + [featurestore_service.GetFeatureRequest], + feature.Feature]: + r"""Return a callable for the get feature method over gRPC. + + Gets details of a single Feature. + + Returns: + Callable[[~.GetFeatureRequest], + ~.Feature]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_feature' not in self._stubs: + self._stubs['get_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeature', + request_serializer=featurestore_service.GetFeatureRequest.serialize, + response_deserializer=feature.Feature.deserialize, + ) + return self._stubs['get_feature'] + + @property + def list_features(self) -> Callable[ + [featurestore_service.ListFeaturesRequest], + featurestore_service.ListFeaturesResponse]: + r"""Return a callable for the list features method over gRPC. + + Lists Features in a given EntityType. + + Returns: + Callable[[~.ListFeaturesRequest], + ~.ListFeaturesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_features' not in self._stubs: + self._stubs['list_features'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeatures', + request_serializer=featurestore_service.ListFeaturesRequest.serialize, + response_deserializer=featurestore_service.ListFeaturesResponse.deserialize, + ) + return self._stubs['list_features'] + + @property + def update_feature(self) -> Callable[ + [featurestore_service.UpdateFeatureRequest], + gca_feature.Feature]: + r"""Return a callable for the update feature method over gRPC. + + Updates the parameters of a single Feature. + + Returns: + Callable[[~.UpdateFeatureRequest], + ~.Feature]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_feature' not in self._stubs: + self._stubs['update_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeature', + request_serializer=featurestore_service.UpdateFeatureRequest.serialize, + response_deserializer=gca_feature.Feature.deserialize, + ) + return self._stubs['update_feature'] + + @property + def delete_feature(self) -> Callable[ + [featurestore_service.DeleteFeatureRequest], + operations.Operation]: + r"""Return a callable for the delete feature method over gRPC. + + Deletes a single Feature. + + Returns: + Callable[[~.DeleteFeatureRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_feature' not in self._stubs: + self._stubs['delete_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeature', + request_serializer=featurestore_service.DeleteFeatureRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_feature'] + + @property + def import_feature_values(self) -> Callable[ + [featurestore_service.ImportFeatureValuesRequest], + operations.Operation]: + r"""Return a callable for the import feature values method over gRPC. + + Imports Feature values into the Featurestore from a + source storage. + The progress of the import is tracked by the returned + operation. The imported features are guaranteed to be + visible to subsequent read operations after the + operation is marked as successfully done. + If an import operation fails, the Feature values + returned from reads and exports may be inconsistent. If + consistency is required, the caller must retry the same + import request again and wait till the new operation + returned is marked as successfully done. + There are also scenarios where the caller can cause + inconsistency. + - Source data for import contains multiple distinct + Feature values for the same entity ID and timestamp. + - Source is modified during an import. This includes + adding, updating, or removing source data and/or + metadata. Examples of updating metadata include but are + not limited to changing storage location, storage class, + or retention policy. + - Online serving cluster is under-provisioned. + + Returns: + Callable[[~.ImportFeatureValuesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'import_feature_values' not in self._stubs: + self._stubs['import_feature_values'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ImportFeatureValues', + request_serializer=featurestore_service.ImportFeatureValuesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['import_feature_values'] + + @property + def batch_read_feature_values(self) -> Callable[ + [featurestore_service.BatchReadFeatureValuesRequest], + operations.Operation]: + r"""Return a callable for the batch read feature values method over gRPC. + + Batch reads Feature values from a Featurestore. + This API enables batch reading Feature values, where + each read instance in the batch may read Feature values + of entities from one or more EntityTypes. Point-in-time + correctness is guaranteed for Feature values of each + read instance as of each instance's read timestamp. + + Returns: + Callable[[~.BatchReadFeatureValuesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_read_feature_values' not in self._stubs: + self._stubs['batch_read_feature_values'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchReadFeatureValues', + request_serializer=featurestore_service.BatchReadFeatureValuesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['batch_read_feature_values'] + + @property + def search_features(self) -> Callable[ + [featurestore_service.SearchFeaturesRequest], + featurestore_service.SearchFeaturesResponse]: + r"""Return a callable for the search features method over gRPC. + + Searches Features matching a query in a given + project. + + Returns: + Callable[[~.SearchFeaturesRequest], + ~.SearchFeaturesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'search_features' not in self._stubs: + self._stubs['search_features'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/SearchFeatures', + request_serializer=featurestore_service.SearchFeaturesRequest.serialize, + response_deserializer=featurestore_service.SearchFeaturesResponse.deserialize, + ) + return self._stubs['search_features'] + + +__all__ = ( + 'FeaturestoreServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..97114a68be --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py @@ -0,0 +1,777 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import FeaturestoreServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import FeaturestoreServiceGrpcTransport + + +class FeaturestoreServiceGrpcAsyncIOTransport(FeaturestoreServiceTransport): + """gRPC AsyncIO backend transport for FeaturestoreService. + + The service that handles CRUD and List for resources for + Featurestore. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_featurestore(self) -> Callable[ + [featurestore_service.CreateFeaturestoreRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create featurestore method over gRPC. + + Creates a new Featurestore in a given project and + location. + + Returns: + Callable[[~.CreateFeaturestoreRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_featurestore' not in self._stubs: + self._stubs['create_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeaturestore', + request_serializer=featurestore_service.CreateFeaturestoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_featurestore'] + + @property + def get_featurestore(self) -> Callable[ + [featurestore_service.GetFeaturestoreRequest], + Awaitable[featurestore.Featurestore]]: + r"""Return a callable for the get featurestore method over gRPC. + + Gets details of a single Featurestore. + + Returns: + Callable[[~.GetFeaturestoreRequest], + Awaitable[~.Featurestore]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_featurestore' not in self._stubs: + self._stubs['get_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeaturestore', + request_serializer=featurestore_service.GetFeaturestoreRequest.serialize, + response_deserializer=featurestore.Featurestore.deserialize, + ) + return self._stubs['get_featurestore'] + + @property + def list_featurestores(self) -> Callable[ + [featurestore_service.ListFeaturestoresRequest], + Awaitable[featurestore_service.ListFeaturestoresResponse]]: + r"""Return a callable for the list featurestores method over gRPC. + + Lists Featurestores in a given project and location. + + Returns: + Callable[[~.ListFeaturestoresRequest], + Awaitable[~.ListFeaturestoresResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_featurestores' not in self._stubs: + self._stubs['list_featurestores'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeaturestores', + request_serializer=featurestore_service.ListFeaturestoresRequest.serialize, + response_deserializer=featurestore_service.ListFeaturestoresResponse.deserialize, + ) + return self._stubs['list_featurestores'] + + @property + def update_featurestore(self) -> Callable[ + [featurestore_service.UpdateFeaturestoreRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the update featurestore method over gRPC. + + Updates the parameters of a single Featurestore. + + Returns: + Callable[[~.UpdateFeaturestoreRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_featurestore' not in self._stubs: + self._stubs['update_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeaturestore', + request_serializer=featurestore_service.UpdateFeaturestoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_featurestore'] + + @property + def delete_featurestore(self) -> Callable[ + [featurestore_service.DeleteFeaturestoreRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete featurestore method over gRPC. + + Deletes a single Featurestore. The Featurestore must not contain + any EntityTypes or ``force`` must be set to true for the request + to succeed. + + Returns: + Callable[[~.DeleteFeaturestoreRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_featurestore' not in self._stubs: + self._stubs['delete_featurestore'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeaturestore', + request_serializer=featurestore_service.DeleteFeaturestoreRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_featurestore'] + + @property + def create_entity_type(self) -> Callable[ + [featurestore_service.CreateEntityTypeRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create entity type method over gRPC. + + Creates a new EntityType in a given Featurestore. + + Returns: + Callable[[~.CreateEntityTypeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_entity_type' not in self._stubs: + self._stubs['create_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateEntityType', + request_serializer=featurestore_service.CreateEntityTypeRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_entity_type'] + + @property + def get_entity_type(self) -> Callable[ + [featurestore_service.GetEntityTypeRequest], + Awaitable[entity_type.EntityType]]: + r"""Return a callable for the get entity type method over gRPC. + + Gets details of a single EntityType. + + Returns: + Callable[[~.GetEntityTypeRequest], + Awaitable[~.EntityType]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_entity_type' not in self._stubs: + self._stubs['get_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetEntityType', + request_serializer=featurestore_service.GetEntityTypeRequest.serialize, + response_deserializer=entity_type.EntityType.deserialize, + ) + return self._stubs['get_entity_type'] + + @property + def list_entity_types(self) -> Callable[ + [featurestore_service.ListEntityTypesRequest], + Awaitable[featurestore_service.ListEntityTypesResponse]]: + r"""Return a callable for the list entity types method over gRPC. + + Lists EntityTypes in a given Featurestore. + + Returns: + Callable[[~.ListEntityTypesRequest], + Awaitable[~.ListEntityTypesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_entity_types' not in self._stubs: + self._stubs['list_entity_types'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListEntityTypes', + request_serializer=featurestore_service.ListEntityTypesRequest.serialize, + response_deserializer=featurestore_service.ListEntityTypesResponse.deserialize, + ) + return self._stubs['list_entity_types'] + + @property + def update_entity_type(self) -> Callable[ + [featurestore_service.UpdateEntityTypeRequest], + Awaitable[gca_entity_type.EntityType]]: + r"""Return a callable for the update entity type method over gRPC. + + Updates the parameters of a single EntityType. + + Returns: + Callable[[~.UpdateEntityTypeRequest], + Awaitable[~.EntityType]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_entity_type' not in self._stubs: + self._stubs['update_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateEntityType', + request_serializer=featurestore_service.UpdateEntityTypeRequest.serialize, + response_deserializer=gca_entity_type.EntityType.deserialize, + ) + return self._stubs['update_entity_type'] + + @property + def delete_entity_type(self) -> Callable[ + [featurestore_service.DeleteEntityTypeRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete entity type method over gRPC. + + Deletes a single EntityType. The EntityType must not have any + Features or ``force`` must be set to true for the request to + succeed. + + Returns: + Callable[[~.DeleteEntityTypeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_entity_type' not in self._stubs: + self._stubs['delete_entity_type'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteEntityType', + request_serializer=featurestore_service.DeleteEntityTypeRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_entity_type'] + + @property + def create_feature(self) -> Callable[ + [featurestore_service.CreateFeatureRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create feature method over gRPC. + + Creates a new Feature in a given EntityType. + + Returns: + Callable[[~.CreateFeatureRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_feature' not in self._stubs: + self._stubs['create_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeature', + request_serializer=featurestore_service.CreateFeatureRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_feature'] + + @property + def batch_create_features(self) -> Callable[ + [featurestore_service.BatchCreateFeaturesRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the batch create features method over gRPC. + + Creates a batch of Features in a given EntityType. + + Returns: + Callable[[~.BatchCreateFeaturesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_create_features' not in self._stubs: + self._stubs['batch_create_features'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchCreateFeatures', + request_serializer=featurestore_service.BatchCreateFeaturesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['batch_create_features'] + + @property + def get_feature(self) -> Callable[ + [featurestore_service.GetFeatureRequest], + Awaitable[feature.Feature]]: + r"""Return a callable for the get feature method over gRPC. + + Gets details of a single Feature. + + Returns: + Callable[[~.GetFeatureRequest], + Awaitable[~.Feature]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_feature' not in self._stubs: + self._stubs['get_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeature', + request_serializer=featurestore_service.GetFeatureRequest.serialize, + response_deserializer=feature.Feature.deserialize, + ) + return self._stubs['get_feature'] + + @property + def list_features(self) -> Callable[ + [featurestore_service.ListFeaturesRequest], + Awaitable[featurestore_service.ListFeaturesResponse]]: + r"""Return a callable for the list features method over gRPC. + + Lists Features in a given EntityType. + + Returns: + Callable[[~.ListFeaturesRequest], + Awaitable[~.ListFeaturesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_features' not in self._stubs: + self._stubs['list_features'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeatures', + request_serializer=featurestore_service.ListFeaturesRequest.serialize, + response_deserializer=featurestore_service.ListFeaturesResponse.deserialize, + ) + return self._stubs['list_features'] + + @property + def update_feature(self) -> Callable[ + [featurestore_service.UpdateFeatureRequest], + Awaitable[gca_feature.Feature]]: + r"""Return a callable for the update feature method over gRPC. + + Updates the parameters of a single Feature. + + Returns: + Callable[[~.UpdateFeatureRequest], + Awaitable[~.Feature]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_feature' not in self._stubs: + self._stubs['update_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeature', + request_serializer=featurestore_service.UpdateFeatureRequest.serialize, + response_deserializer=gca_feature.Feature.deserialize, + ) + return self._stubs['update_feature'] + + @property + def delete_feature(self) -> Callable[ + [featurestore_service.DeleteFeatureRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete feature method over gRPC. + + Deletes a single Feature. + + Returns: + Callable[[~.DeleteFeatureRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_feature' not in self._stubs: + self._stubs['delete_feature'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeature', + request_serializer=featurestore_service.DeleteFeatureRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_feature'] + + @property + def import_feature_values(self) -> Callable[ + [featurestore_service.ImportFeatureValuesRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the import feature values method over gRPC. + + Imports Feature values into the Featurestore from a + source storage. + The progress of the import is tracked by the returned + operation. The imported features are guaranteed to be + visible to subsequent read operations after the + operation is marked as successfully done. + If an import operation fails, the Feature values + returned from reads and exports may be inconsistent. If + consistency is required, the caller must retry the same + import request again and wait till the new operation + returned is marked as successfully done. + There are also scenarios where the caller can cause + inconsistency. + - Source data for import contains multiple distinct + Feature values for the same entity ID and timestamp. + - Source is modified during an import. This includes + adding, updating, or removing source data and/or + metadata. Examples of updating metadata include but are + not limited to changing storage location, storage class, + or retention policy. + - Online serving cluster is under-provisioned. + + Returns: + Callable[[~.ImportFeatureValuesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'import_feature_values' not in self._stubs: + self._stubs['import_feature_values'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ImportFeatureValues', + request_serializer=featurestore_service.ImportFeatureValuesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['import_feature_values'] + + @property + def batch_read_feature_values(self) -> Callable[ + [featurestore_service.BatchReadFeatureValuesRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the batch read feature values method over gRPC. + + Batch reads Feature values from a Featurestore. + This API enables batch reading Feature values, where + each read instance in the batch may read Feature values + of entities from one or more EntityTypes. Point-in-time + correctness is guaranteed for Feature values of each + read instance as of each instance's read timestamp. + + Returns: + Callable[[~.BatchReadFeatureValuesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_read_feature_values' not in self._stubs: + self._stubs['batch_read_feature_values'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchReadFeatureValues', + request_serializer=featurestore_service.BatchReadFeatureValuesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['batch_read_feature_values'] + + @property + def search_features(self) -> Callable[ + [featurestore_service.SearchFeaturesRequest], + Awaitable[featurestore_service.SearchFeaturesResponse]]: + r"""Return a callable for the search features method over gRPC. + + Searches Features matching a query in a given + project. + + Returns: + Callable[[~.SearchFeaturesRequest], + Awaitable[~.SearchFeaturesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'search_features' not in self._stubs: + self._stubs['search_features'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.FeaturestoreService/SearchFeatures', + request_serializer=featurestore_service.SearchFeaturesRequest.serialize, + response_deserializer=featurestore_service.SearchFeaturesResponse.deserialize, + ) + return self._stubs['search_features'] + + +__all__ = ( + 'FeaturestoreServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 8b0e8331bb..d250cb7dfc 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index cb4d402b6a..de6c880e58 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job @@ -717,7 +717,7 @@ def delete_custom_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a CustomJob. Args: @@ -800,7 +800,7 @@ def delete_custom_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -1156,7 +1156,7 @@ def delete_data_labeling_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1239,7 +1239,7 @@ def delete_data_labeling_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -1589,7 +1589,7 @@ def delete_hyperparameter_tuning_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1672,7 +1672,7 @@ def delete_hyperparameter_tuning_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -2040,7 +2040,7 @@ def delete_batch_prediction_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2124,7 +2124,7 @@ def delete_batch_prediction_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -2591,7 +2591,7 @@ def update_model_deployment_monitoring_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Updates a ModelDeploymentMonitoringJob. Args: @@ -2673,7 +2673,7 @@ def update_model_deployment_monitoring_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob, @@ -2690,7 +2690,7 @@ def delete_model_deployment_monitoring_job(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a ModelDeploymentMonitoringJob. Args: @@ -2773,7 +2773,7 @@ def delete_model_deployment_monitoring_job(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index d47a250882..5b5275ba33 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers from google.cloud.aiplatform_v1beta1.types import artifact diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index e1fcc67567..06ca29cf5a 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers from google.cloud.aiplatform_v1beta1.types import artifact @@ -397,7 +397,7 @@ def create_metadata_store(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Initializes a MetadataStore, including allocation of resources. @@ -497,7 +497,7 @@ def create_metadata_store(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_metadata_store.MetadataStore, @@ -685,7 +685,7 @@ def delete_metadata_store(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a single MetadataStore. Args: @@ -768,7 +768,7 @@ def delete_metadata_store(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -1508,7 +1508,7 @@ def delete_context(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a stored Context. Args: @@ -1591,7 +1591,7 @@ def delete_context(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index 72cfd1e4e4..ac7b775897 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 29e081bc10..224f714816 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref @@ -390,7 +390,7 @@ def upload_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -472,7 +472,7 @@ def upload_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, model_service.UploadModelResponse, @@ -743,7 +743,7 @@ def delete_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -828,7 +828,7 @@ def delete_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -846,7 +846,7 @@ def export_model(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -933,7 +933,7 @@ def export_model(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, model_service.ExportModelResponse, diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 6235697be1..50fffd8438 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.types import encryption_spec diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 07f1ac0444..aa99b2c0c3 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.types import encryption_spec @@ -633,7 +633,7 @@ def delete_training_pipeline(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -716,7 +716,7 @@ def delete_training_pipeline(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index a6de6886e7..27364bb495 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -28,7 +28,7 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.types import operation as gca_operation diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index 813d6413ff..caa4f9aa26 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -32,7 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -345,7 +345,7 @@ def create_specialist_pool(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -432,7 +432,7 @@ def create_specialist_pool(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_specialist_pool.SpecialistPool, @@ -628,7 +628,7 @@ def delete_specialist_pool(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -712,7 +712,7 @@ def delete_specialist_pool(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, empty.Empty, @@ -730,7 +730,7 @@ def update_specialist_pool(self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + ) -> gac_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -816,7 +816,7 @@ def update_specialist_pool(self, ) # Wrap the response in an operation future. - response = ga_operation.from_gapic( + response = gac_operation.from_gapic( response, self._transport.operations_client, gca_specialist_pool.SpecialistPool, diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 8cc21f36ae..aff56b122d 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -100,6 +100,9 @@ UndeployModelResponse, UpdateEndpointRequest, ) +from .entity_type import ( + EntityType, +) from .env_var import ( EnvVar, ) @@ -126,18 +129,80 @@ from .explanation_metadata import ( ExplanationMetadata, ) +from .feature import ( + Feature, +) from .feature_monitoring_stats import ( FeatureStatsAnomaly, ) +from .feature_selector import ( + FeatureSelector, + IdMatcher, +) +from .featurestore import ( + Featurestore, +) +from .featurestore_monitoring import ( + FeaturestoreMonitoringConfig, +) +from .featurestore_online_service import ( + FeatureValue, + FeatureValueList, + ReadFeatureValuesRequest, + ReadFeatureValuesResponse, + ReadSetting, + StreamingReadFeatureValuesRequest, +) +from .featurestore_service import ( + BatchCreateFeaturesOperationMetadata, + BatchCreateFeaturesRequest, + BatchCreateFeaturesResponse, + BatchReadFeatureValuesOperationMetadata, + BatchReadFeatureValuesRequest, + BatchReadFeatureValuesResponse, + CreateEntityTypeOperationMetadata, + CreateEntityTypeRequest, + CreateFeatureOperationMetadata, + CreateFeatureRequest, + CreateFeaturestoreOperationMetadata, + CreateFeaturestoreRequest, + DeleteEntityTypeRequest, + DeleteFeatureRequest, + DeleteFeaturestoreRequest, + DestinationFeatureSetting, + FeatureValueDestination, + GetEntityTypeRequest, + GetFeatureRequest, + GetFeaturestoreRequest, + ImportFeatureValuesOperationMetadata, + ImportFeatureValuesRequest, + ImportFeatureValuesResponse, + ListEntityTypesRequest, + ListEntityTypesResponse, + ListFeaturesRequest, + ListFeaturesResponse, + ListFeaturestoresRequest, + ListFeaturestoresResponse, + SearchFeaturesRequest, + SearchFeaturesResponse, + UpdateEntityTypeRequest, + UpdateFeatureRequest, + UpdateFeaturestoreOperationMetadata, + UpdateFeaturestoreRequest, +) from .hyperparameter_tuning_job import ( HyperparameterTuningJob, ) from .io import ( + AvroSource, BigQueryDestination, BigQuerySource, ContainerRegistryDestination, + CsvDestination, + CsvSource, GcsDestination, GcsSource, + TFRecordDestination, ) from .job_service import ( CancelBatchPredictionJobRequest, @@ -336,6 +401,12 @@ TimestampSplit, TrainingPipeline, ) +from .types import ( + BoolArray, + DoubleArray, + Int64Array, + StringArray, +) from .user_action_reference import ( UserActionReference, ) @@ -421,6 +492,7 @@ 'UndeployModelRequest', 'UndeployModelResponse', 'UpdateEndpointRequest', + 'EntityType', 'EnvVar', 'Event', 'Execution', @@ -437,13 +509,63 @@ 'SmoothGradConfig', 'XraiAttribution', 'ExplanationMetadata', + 'Feature', 'FeatureStatsAnomaly', + 'FeatureSelector', + 'IdMatcher', + 'Featurestore', + 'FeaturestoreMonitoringConfig', + 'FeatureValue', + 'FeatureValueList', + 'ReadFeatureValuesRequest', + 'ReadFeatureValuesResponse', + 'ReadSetting', + 'StreamingReadFeatureValuesRequest', + 'BatchCreateFeaturesOperationMetadata', + 'BatchCreateFeaturesRequest', + 'BatchCreateFeaturesResponse', + 'BatchReadFeatureValuesOperationMetadata', + 'BatchReadFeatureValuesRequest', + 'BatchReadFeatureValuesResponse', + 'CreateEntityTypeOperationMetadata', + 'CreateEntityTypeRequest', + 'CreateFeatureOperationMetadata', + 'CreateFeatureRequest', + 'CreateFeaturestoreOperationMetadata', + 'CreateFeaturestoreRequest', + 'DeleteEntityTypeRequest', + 'DeleteFeatureRequest', + 'DeleteFeaturestoreRequest', + 'DestinationFeatureSetting', + 'FeatureValueDestination', + 'GetEntityTypeRequest', + 'GetFeatureRequest', + 'GetFeaturestoreRequest', + 'ImportFeatureValuesOperationMetadata', + 'ImportFeatureValuesRequest', + 'ImportFeatureValuesResponse', + 'ListEntityTypesRequest', + 'ListEntityTypesResponse', + 'ListFeaturesRequest', + 'ListFeaturesResponse', + 'ListFeaturestoresRequest', + 'ListFeaturestoresResponse', + 'SearchFeaturesRequest', + 'SearchFeaturesResponse', + 'UpdateEntityTypeRequest', + 'UpdateFeatureRequest', + 'UpdateFeaturestoreOperationMetadata', + 'UpdateFeaturestoreRequest', 'HyperparameterTuningJob', + 'AvroSource', 'BigQueryDestination', 'BigQuerySource', 'ContainerRegistryDestination', + 'CsvDestination', + 'CsvSource', 'GcsDestination', 'GcsSource', + 'TFRecordDestination', 'CancelBatchPredictionJobRequest', 'CancelCustomJobRequest', 'CancelDataLabelingJobRequest', @@ -599,6 +721,10 @@ 'PredefinedSplit', 'TimestampSplit', 'TrainingPipeline', + 'BoolArray', + 'DoubleArray', + 'Int64Array', + 'StringArray', 'UserActionReference', 'AddTrialMeasurementRequest', 'CheckTrialEarlyStoppingStateMetatdata', diff --git a/google/cloud/aiplatform_v1beta1/types/entity_type.py b/google/cloud/aiplatform_v1beta1/types/entity_type.py new file mode 100644 index 0000000000..38448a20c3 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/entity_type.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import featurestore_monitoring +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'EntityType', + }, +) + + +class EntityType(proto.Message): + r"""An entity type is a type of object in a system that needs to + be modeled and have stored information about. For example, + driver is an entity type, and driver0 is an instance of an + entity type driver. + + Attributes: + name (str): + Immutable. Name of the EntityType. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + The last part entity_type is assigned by the client. The + entity_type can be up to 64 characters long and can consist + only of ASCII Latin letters A-Z and a-z and underscore(_), + and ASCII digits 0-9 starting with a letter. The value will + be unique given a featurestore. + description (str): + Optional. Description of the EntityType. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this EntityType + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this EntityType + was most recently updated. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.EntityType.LabelsEntry]): + Optional. The labels with user-defined + metadata to organize your EntityTypes. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + on and examples of labels. No more than 64 user + labels can be associated with one EntityType + (System labels are excluded)." + System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + etag (str): + Optional. Used to perform a consistent read- + odify-write updates. If not set, a blind + "overwrite" update happens. + monitoring_config (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig): + If this is populated with + [FeaturestoreMonitoringConfig.monitoring_interval] + specified, snapshot analysis monitoring is enabled. + Otherwise, snapshot analysis monitoring is disabled. + """ + + name = proto.Field(proto.STRING, number=1) + + description = proto.Field(proto.STRING, number=2) + + create_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + labels = proto.MapField(proto.STRING, proto.STRING, number=6) + + etag = proto.Field(proto.STRING, number=7) + + monitoring_config = proto.Field(proto.MESSAGE, number=8, + message=featurestore_monitoring.FeaturestoreMonitoringConfig, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/feature.py b/google/cloud/aiplatform_v1beta1/types/feature.py new file mode 100644 index 0000000000..f34a825fab --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/feature.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import feature_monitoring_stats +from google.cloud.aiplatform_v1beta1.types import featurestore_monitoring +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Feature', + }, +) + + +class Feature(proto.Message): + r"""Feature Metadata information that describes an attribute of + an entity type. For example, apple is an entity type, and color + is a feature that describes apple. + + Attributes: + name (str): + Immutable. Name of the Feature. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + + The last part feature is assigned by the client. The feature + can be up to 64 characters long and can consist only of + ASCII Latin letters A-Z and a-z, underscore(_), and ASCII + digits 0-9 starting with a letter. The value will be unique + given an entity type. + description (str): + Description of the Feature. + value_type (google.cloud.aiplatform_v1beta1.types.Feature.ValueType): + Required. Immutable. Type of Feature value. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this EntityType + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this EntityType + was most recently updated. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Feature.LabelsEntry]): + Optional. The labels with user-defined + metadata to organize your Features. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + on and examples of labels. No more than 64 user + labels can be associated with one Feature + (System labels are excluded)." + System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + etag (str): + Used to perform a consistent read-modify- + rite updates. If not set, a blind "overwrite" + update happens. + monitoring_config (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig): + If this is populated with + [FeaturestoreMonitoringConfig.disabled][] = true, snapshot + analysis monitoring is disabled; if + [FeaturestoreMonitoringConfig.monitoring_interval][] + specified, snapshot analysis monitoring is enabled. + Otherwise, snapshot analysis monitoring config is same as + the EntityType's this Feature belongs to. + monitoring_stats (Sequence[google.cloud.aiplatform_v1beta1.types.FeatureStatsAnomaly]): + Output only. A list of historical [Snapshot + Analysis][google.cloud.aiplatform.master.FeaturestoreMonitoringConfig.SnapshotAnalysis] + stats requested by user, sorted by + ``FeatureStatsAnomaly.start_time`` + descending. + """ + class ValueType(proto.Enum): + r"""An enum representing the value type of a feature.""" + VALUE_TYPE_UNSPECIFIED = 0 + BOOL = 1 + BOOL_ARRAY = 2 + DOUBLE = 3 + DOUBLE_ARRAY = 4 + INT64 = 9 + INT64_ARRAY = 10 + STRING = 11 + STRING_ARRAY = 12 + BYTES = 13 + + name = proto.Field(proto.STRING, number=1) + + description = proto.Field(proto.STRING, number=2) + + value_type = proto.Field(proto.ENUM, number=3, + enum=ValueType, + ) + + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) + + labels = proto.MapField(proto.STRING, proto.STRING, number=6) + + etag = proto.Field(proto.STRING, number=7) + + monitoring_config = proto.Field(proto.MESSAGE, number=9, + message=featurestore_monitoring.FeaturestoreMonitoringConfig, + ) + + monitoring_stats = proto.RepeatedField(proto.MESSAGE, number=10, + message=feature_monitoring_stats.FeatureStatsAnomaly, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/feature_selector.py b/google/cloud/aiplatform_v1beta1/types/feature_selector.py new file mode 100644 index 0000000000..346029f8f7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/feature_selector.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'IdMatcher', + 'FeatureSelector', + }, +) + + +class IdMatcher(proto.Message): + r"""Matcher for Features of an EntityType by Feature ID. + + Attributes: + ids (Sequence[str]): + Required. The following are accepted as ``ids``: + + - A single-element list containing only ``*``, which + selects all Features in the target EntityType, or + - A list containing only Feature IDs, which selects only + Features with those IDs in the target EntityType. + """ + + ids = proto.RepeatedField(proto.STRING, number=1) + + +class FeatureSelector(proto.Message): + r"""Selector for Features of an EntityType. + + Attributes: + id_matcher (google.cloud.aiplatform_v1beta1.types.IdMatcher): + Required. Matches Features based on ID. + """ + + id_matcher = proto.Field(proto.MESSAGE, number=1, + message='IdMatcher', + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore.py b/google/cloud/aiplatform_v1beta1/types/featurestore.py new file mode 100644 index 0000000000..588d587e2f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/featurestore.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Featurestore', + }, +) + + +class Featurestore(proto.Message): + r"""Featurestore configuration information on how the + Featurestore is configured. + + Attributes: + name (str): + Output only. Name of the Featurestore. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + display_name (str): + Required. The user-defined name of the + Featurestore. The name can be up to 128 + characters long and can consist of any UTF-8 + characters. + Display name of a Featurestore must be unique + within a single Project and Location Pair. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Featurestore + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Featurestore + was last updated. + etag (str): + Optional. Used to perform consistent read- + odify-write updates. If not set, a blind + "overwrite" update happens. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Featurestore.LabelsEntry]): + Optional. The labels with user-defined + metadata to organize your Featurestore. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + on and examples of labels. No more than 64 user + labels can be associated with one + Featurestore(System labels are excluded)." + System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + online_serving_config (google.cloud.aiplatform_v1beta1.types.Featurestore.OnlineServingConfig): + Required. Config for online serving + resources. + state (google.cloud.aiplatform_v1beta1.types.Featurestore.State): + Output only. State of the featurestore. + """ + class State(proto.Enum): + r"""Possible states a Featurestore can have.""" + STATE_UNSPECIFIED = 0 + STABLE = 1 + UPDATING = 2 + + class OnlineServingConfig(proto.Message): + r"""OnlineServingConfig specifies the details for provisioning + online serving resources. + + Attributes: + fixed_node_count (int): + Required. The number of nodes for each + cluster. The number of nodes will not scale + automatically but can be scaled manually by + providing different values when updating. + max_online_serving_size (int): + Maximum number of feature values per entity + that will be stored in online serving storage. + The Featurestore will retain the latest feature + values per entity and periodically remove any + older feature values. It can take up to a day + before the older feature values are deleted. + Storage infrastructure cost is propotional to + this value. Recommend to set to the largest + number of versions (i.e last-k) needed at online + serving time. If not set, default to 1. + """ + + fixed_node_count = proto.Field(proto.INT32, number=2) + + max_online_serving_size = proto.Field(proto.INT32, number=3) + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + create_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + etag = proto.Field(proto.STRING, number=5) + + labels = proto.MapField(proto.STRING, proto.STRING, number=6) + + online_serving_config = proto.Field(proto.MESSAGE, number=7, + message=OnlineServingConfig, + ) + + state = proto.Field(proto.ENUM, number=8, + enum=State, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py b/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py new file mode 100644 index 0000000000..a13e0778f4 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import duration_pb2 as duration # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'FeaturestoreMonitoringConfig', + }, +) + + +class FeaturestoreMonitoringConfig(proto.Message): + r"""Configuration of how features in Featurestore are monitored. + + Attributes: + snapshot_analysis (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig.SnapshotAnalysis): + The config for Snapshot Analysis Based + Feature Monitoring. + """ + class SnapshotAnalysis(proto.Message): + r"""Configuration of the Featurestore's Snapshot Analysis Based + Monitoring. This type of analysis generates statistics for each + Feature based on a snapshot of the latest feature value of each + entities every monitoring_interval. + + Attributes: + disabled (bool): + The monitoring schedule for snapshot analysis. For + EntityType-level config: unset / disabled = true indicates + disabled by default for Features under it; otherwise by + default enable snapshot analysis monitoring with + monitoring_interval for Features under it. Feature-level + config: disabled = true indicates disabled regardless of the + EntityType-level config; unset monitoring_interval indicates + going with EntityType-level config; otherwise run snapshot + analysis monitoring with monitoring_interval regardless of + the EntityType-level config. Explicitly Disable the snapshot + analysis based monitoring. + monitoring_interval (google.protobuf.duration_pb2.Duration): + Configuration of the snapshot analysis based + monitoring pipeline running interval. The value + is rolled up to full day. + """ + + disabled = proto.Field(proto.BOOL, number=1) + + monitoring_interval = proto.Field(proto.MESSAGE, number=2, + message=duration.Duration, + ) + + snapshot_analysis = proto.Field(proto.MESSAGE, number=1, + message=SnapshotAnalysis, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py new file mode 100644 index 0000000000..2564346039 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py @@ -0,0 +1,343 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import feature_selector as gca_feature_selector +from google.cloud.aiplatform_v1beta1.types import types +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ReadFeatureValuesRequest', + 'ReadSetting', + 'ReadFeatureValuesResponse', + 'StreamingReadFeatureValuesRequest', + 'FeatureValue', + 'FeatureValueList', + }, +) + + +class ReadFeatureValuesRequest(proto.Message): + r"""Request message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + + Attributes: + entity_type (str): + Required. The resource name of the EntityType for the entity + being read. Value format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + For example, for a machine learning model predicting user + clicks on a website, an EntityType ID could be "user". + entity_id (str): + Required. ID for a specific entity. For example, for a + machine learning model predicting user clicks on a website, + an entity ID could be "user_123". + feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): + Required. Selector choosing Features of the + target EntityType. + setting (google.cloud.aiplatform_v1beta1.types.ReadSetting): + Setting to apply to all Feature values being + read, by default. + setting_overrides (Sequence[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesRequest.SettingOverridesEntry]): + Map from Feature ID to settings to apply to Feature values + being read. If no setting is specified for a Feature + selected by + ``ReadFeatureValuesRequest.feature_selector``, + the default + ``ReadFeatureValuesRequest.setting`` + will be used. + """ + + entity_type = proto.Field(proto.STRING, number=1) + + entity_id = proto.Field(proto.STRING, number=2) + + feature_selector = proto.Field(proto.MESSAGE, number=3, + message=gca_feature_selector.FeatureSelector, + ) + + setting = proto.Field(proto.MESSAGE, number=5, + message='ReadSetting', + ) + + setting_overrides = proto.MapField(proto.STRING, proto.MESSAGE, number=6, + message='ReadSetting', + ) + + +class ReadSetting(proto.Message): + r"""Setting to apply when reading Feature values, e.g. "limit + read to the K-latest values". + + Attributes: + values_count (int): + Number of values, successive in time, to + retrieve for a Feature. If not set, default to + 1. Must be less than or equal to 32. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Retrieve latest values before or at this + timestamp. If not set, retrieve latest values. + Resolution in millisecond. Request will fail if + timestamp is not millisecond-aligned. + """ + + values_count = proto.Field(proto.INT32, number=2) + + read_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) + + +class ReadFeatureValuesResponse(proto.Message): + r"""Response message for + ``FeaturestoreOnlineServingService.ReadFeatureValues``. + + Attributes: + header (google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse.Header): + Response header. + entity_view (google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse.EntityView): + Entity view with Feature values. This may be + the entity in the Featurestore if values for all + Features were requested, or a projection of the + entity in the Featurestore if values for only + some Features were requested. + """ + class FeatureDescriptor(proto.Message): + r"""Metadata for requested Features. + + Attributes: + id (str): + Feature ID. + """ + + id = proto.Field(proto.STRING, number=1) + + class Header(proto.Message): + r"""Response header with metadata for the requested + ``ReadFeatureValuesRequest.entity_type`` + and Features. + + Attributes: + entity_type (str): + The resource name of the EntityType from the + ``ReadFeatureValuesRequest``. + Value format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + feature_descriptors (Sequence[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse.FeatureDescriptor]): + List of Feature metadata corresponding to each piece of + [ReadFeatureValuesResponse.data][]. + """ + + entity_type = proto.Field(proto.STRING, number=1) + + feature_descriptors = proto.RepeatedField(proto.MESSAGE, number=2, + message='ReadFeatureValuesResponse.FeatureDescriptor', + ) + + class EntityView(proto.Message): + r"""Entity view with Feature values. + + Attributes: + entity_id (str): + ID of the requested entity. + data (Sequence[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse.EntityView.Data]): + Each piece of data holds the k requested values for one + requested Feature. If no values for the requested Feature + exist, the corresponding cell will be empty. This has the + same size and is in the same order as the features from the + header + ``ReadFeatureValuesResponse.header``. + """ + class Data(proto.Message): + r"""Container to hold value(s), successive in time, for one + Feature from the request. + + Attributes: + value (google.cloud.aiplatform_v1beta1.types.FeatureValue): + Feature value if a single value is requested. + values (google.cloud.aiplatform_v1beta1.types.FeatureValueList): + Feature values list if values, successive in + time, are requested. If the requested number of + values is greater than the number of existing + Feature values, nonexistent values are omitted + instead of being returned as empty. + """ + + value = proto.Field(proto.MESSAGE, number=1, oneof='data', + message='FeatureValue', + ) + + values = proto.Field(proto.MESSAGE, number=2, oneof='data', + message='FeatureValueList', + ) + + entity_id = proto.Field(proto.STRING, number=1) + + data = proto.RepeatedField(proto.MESSAGE, number=2, + message='ReadFeatureValuesResponse.EntityView.Data', + ) + + header = proto.Field(proto.MESSAGE, number=1, + message=Header, + ) + + entity_view = proto.Field(proto.MESSAGE, number=2, + message=EntityView, + ) + + +class StreamingReadFeatureValuesRequest(proto.Message): + r"""Request message for + [FeaturestoreOnlineServingService.StreamingFeatureValuesRead][]. + + Attributes: + entity_type (str): + Required. The resource name of the entities' type. Value + format: + ``projects/{project}/locations/{location}/featurestores/ {featurestore}/entityTypes/{entityType}``. + For example, for a machine learning model predicting user + clicks on a website, an EntityType ID could be "user". + entity_ids (Sequence[str]): + Required. IDs of entities to read Feature values of. For + example, for a machine learning model predicting user clicks + on a website, an entity ID could be "user_123". + feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): + Required. Selector choosing Features of the + target EntityType. + setting (google.cloud.aiplatform_v1beta1.types.ReadSetting): + Setting to apply to all Feature values being + read, by default. + setting_overrides (Sequence[google.cloud.aiplatform_v1beta1.types.StreamingReadFeatureValuesRequest.SettingOverridesEntry]): + Map from Feature ID to settings to apply to Feature values + being read. If no setting is specified for a Feature + selected by + ``ReadFeatureValuesRequest.feature_selector``, + the default + ``ReadFeatureValuesRequest.setting`` + will be used. + """ + + entity_type = proto.Field(proto.STRING, number=1) + + entity_ids = proto.RepeatedField(proto.STRING, number=2) + + feature_selector = proto.Field(proto.MESSAGE, number=3, + message=gca_feature_selector.FeatureSelector, + ) + + setting = proto.Field(proto.MESSAGE, number=5, + message='ReadSetting', + ) + + setting_overrides = proto.MapField(proto.STRING, proto.MESSAGE, number=6, + message='ReadSetting', + ) + + +class FeatureValue(proto.Message): + r"""Value for a feature. + NEXT ID: 15 + + Attributes: + bool_value (bool): + Bool type feature value. + double_value (float): + Double type feature value. + int64_value (int): + Int64 feature value. + string_value (str): + String feature value. + bool_array_value (google.cloud.aiplatform_v1beta1.types.BoolArray): + A list of bool type feature value. + double_array_value (google.cloud.aiplatform_v1beta1.types.DoubleArray): + A list of double type feature value. + int64_array_value (google.cloud.aiplatform_v1beta1.types.Int64Array): + A list of int64 type feature value. + string_array_value (google.cloud.aiplatform_v1beta1.types.StringArray): + A list of string type feature value. + bytes_value (bytes): + Bytes feature value. + metadata (google.cloud.aiplatform_v1beta1.types.FeatureValue.Metadata): + Output only. Metadata of feature value. + """ + class Metadata(proto.Message): + r"""Metadata of feature value. + + Attributes: + generate_time (google.protobuf.timestamp_pb2.Timestamp): + Feature generation timestamp. Typically, it + is provided by user at feature ingestion time. + If not, feature store will use the system + timestamp when the data is ingested into feature + store. + """ + + generate_time = proto.Field(proto.MESSAGE, number=1, + message=timestamp.Timestamp, + ) + + bool_value = proto.Field(proto.BOOL, number=1, oneof='value') + + double_value = proto.Field(proto.DOUBLE, number=2, oneof='value') + + int64_value = proto.Field(proto.INT64, number=5, oneof='value') + + string_value = proto.Field(proto.STRING, number=6, oneof='value') + + bool_array_value = proto.Field(proto.MESSAGE, number=7, oneof='value', + message=types.BoolArray, + ) + + double_array_value = proto.Field(proto.MESSAGE, number=8, oneof='value', + message=types.DoubleArray, + ) + + int64_array_value = proto.Field(proto.MESSAGE, number=11, oneof='value', + message=types.Int64Array, + ) + + string_array_value = proto.Field(proto.MESSAGE, number=12, oneof='value', + message=types.StringArray, + ) + + bytes_value = proto.Field(proto.BYTES, number=13, oneof='value') + + metadata = proto.Field(proto.MESSAGE, number=14, + message=Metadata, + ) + + +class FeatureValueList(proto.Message): + r"""Container for list of values. + + Attributes: + values (Sequence[google.cloud.aiplatform_v1beta1.types.FeatureValue]): + A list of feature values. All of them should + be the same data type. + """ + + values = proto.RepeatedField(proto.MESSAGE, number=1, + message='FeatureValue', + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py new file mode 100644 index 0000000000..c5d1b4034f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py @@ -0,0 +1,1202 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import feature_selector as gca_feature_selector +from google.cloud.aiplatform_v1beta1.types import featurestore as gca_featurestore +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import operation +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'CreateFeaturestoreRequest', + 'GetFeaturestoreRequest', + 'ListFeaturestoresRequest', + 'ListFeaturestoresResponse', + 'UpdateFeaturestoreRequest', + 'DeleteFeaturestoreRequest', + 'ImportFeatureValuesRequest', + 'ImportFeatureValuesResponse', + 'BatchReadFeatureValuesRequest', + 'DestinationFeatureSetting', + 'FeatureValueDestination', + 'BatchReadFeatureValuesResponse', + 'CreateEntityTypeRequest', + 'GetEntityTypeRequest', + 'ListEntityTypesRequest', + 'ListEntityTypesResponse', + 'UpdateEntityTypeRequest', + 'DeleteEntityTypeRequest', + 'CreateFeatureRequest', + 'BatchCreateFeaturesRequest', + 'BatchCreateFeaturesResponse', + 'GetFeatureRequest', + 'ListFeaturesRequest', + 'ListFeaturesResponse', + 'SearchFeaturesRequest', + 'SearchFeaturesResponse', + 'UpdateFeatureRequest', + 'DeleteFeatureRequest', + 'CreateFeaturestoreOperationMetadata', + 'UpdateFeaturestoreOperationMetadata', + 'ImportFeatureValuesOperationMetadata', + 'BatchReadFeatureValuesOperationMetadata', + 'CreateEntityTypeOperationMetadata', + 'CreateFeatureOperationMetadata', + 'BatchCreateFeaturesOperationMetadata', + }, +) + + +class CreateFeaturestoreRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.CreateFeaturestore``. + + Attributes: + parent (str): + Required. The resource name of the Location to create + Featurestores. Format: + ``projects/{project}/locations/{location}'`` + featurestore (google.cloud.aiplatform_v1beta1.types.Featurestore): + Required. The Featurestore to create. + featurestore_id (str): + Required. The ID to use for this Featurestore, which will + become the final component of the Featurestore's resource + name. + + This value may be up to 60 characters, and valid characters + are ``[a-z0-9_]``. The first character cannot be a number. + + The value must be unique within the project and location. + """ + + parent = proto.Field(proto.STRING, number=1) + + featurestore = proto.Field(proto.MESSAGE, number=2, + message=gca_featurestore.Featurestore, + ) + + featurestore_id = proto.Field(proto.STRING, number=3) + + +class GetFeaturestoreRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.GetFeaturestore``. + + Attributes: + name (str): + Required. The name of the Featurestore + resource. + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListFeaturestoresRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.ListFeaturestores``. + + Attributes: + parent (str): + Required. The resource name of the Location to list + Featurestores. Format: + ``projects/{project}/locations/{location}`` + filter (str): + Lists the featurestores that match the filter expression. + The following fields are supported: + + - ``display_name``: Supports =, != comparisons. + - ``create_time``: Supports =, !=, <, >, <=, and >= + comparisons. Values must be in RFC 3339 format. + - ``update_time``: Supports =, !=, <, >, <=, and >= + comparisons. Values must be in RFC 3339 format. + - ``online_serving_config.fixed_node_count``: Supports =, + !=, <, >, <=, and >= comparisons. + - ``labels``: Supports key-value equality and key presence. + + Examples: + + - ``create_time > "2020-01-01" OR update_time > "2020-01-01"`` + Featurestores created or updated after 2020-01-01. + - ``labels.env = "prod"`` Featurestores with label "env" + set to "prod". + page_size (int): + The maximum number of Featurestores to + return. The service may return fewer than this + value. If unspecified, at most 100 Featurestores + will be returned. The maximum value is 100; any + value greater than 100 will be coerced to 100. + page_token (str): + A page token, received from a previous + ``FeaturestoreService.ListFeaturestores`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``FeaturestoreService.ListFeaturestores`` + must match the call that provided the page token. + order_by (str): + A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for + descending. Supported Fields: + + - ``display_name`` + - ``create_time`` + - ``update_time`` + - ``online_serving_config.fixed_node_count`` + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, + message=field_mask.FieldMask, + ) + + +class ListFeaturestoresResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.ListFeaturestores``. + + Attributes: + featurestores (Sequence[google.cloud.aiplatform_v1beta1.types.Featurestore]): + The Featurestores matching the request. + next_page_token (str): + A token, which can be sent as + ``ListFeaturestoresRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + featurestores = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_featurestore.Featurestore, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateFeaturestoreRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.UpdateFeaturestore``. + + Attributes: + featurestore (google.cloud.aiplatform_v1beta1.types.Featurestore): + Required. The Featurestore's ``name`` field is used to + identify the Featurestore to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Field mask is used to specify the fields to be overwritten + in the Featurestore resource by the update. The fields + specified in the update_mask are relative to the resource, + not the full request. A field will be overwritten if it is + in the mask. If the user does not provide a mask then only + the non-empty fields present in the request will be + overwritten. Set the update_mask to ``*`` to override all + fields. + + Updatable fields: + + - ``display_name`` + - ``labels`` + - ``online_serving_config.fixed_node_count`` + - ``online_serving_config.max_online_serving_size`` + """ + + featurestore = proto.Field(proto.MESSAGE, number=1, + message=gca_featurestore.Featurestore, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + +class DeleteFeaturestoreRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.DeleteFeaturestore``. + + Attributes: + name (str): + Required. The name of the Featurestore to be deleted. + Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + force (bool): + If set to true, any EntityTypes and Features + for this Featurestore will also be deleted. + (Otherwise, the request will only work if the + Featurestore has no EntityTypes.) + """ + + name = proto.Field(proto.STRING, number=1) + + force = proto.Field(proto.BOOL, number=2) + + +class ImportFeatureValuesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.ImportFeatureValues``. + + Attributes: + avro_source (google.cloud.aiplatform_v1beta1.types.AvroSource): + + bigquery_source (google.cloud.aiplatform_v1beta1.types.BigQuerySource): + + csv_source (google.cloud.aiplatform_v1beta1.types.CsvSource): + + feature_time_field (str): + Source column that holds the Feature + timestamp for all Feature values in each entity. + feature_time (google.protobuf.timestamp_pb2.Timestamp): + Single Feature timestamp for all entities + being imported. The timestamp must not have + higher than millisecond precision. + entity_type (str): + Required. The resource name of the EntityType grouping the + Features for which values are being imported. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}`` + entity_id_field (str): + Source column that holds entity IDs. If not provided, entity + IDs are extracted from the column named ``entity_id``. + feature_specs (Sequence[google.cloud.aiplatform_v1beta1.types.ImportFeatureValuesRequest.FeatureSpec]): + Required. Specifications defining which Feature values to + import from the entity. The request fails if no + feature_specs are provided, and having multiple + feature_specs for one Feature is not allowed. + disable_online_serving (bool): + If set, data will not be imported for online + serving. This is typically used for backfilling, + where Feature generation timestamps are not in + the timestamp range needed for online serving. + worker_count (int): + Required. Specifies the number of workers + that are used to write data to the Featurestore. + Consider the online serving capacity that you + require to achieve the desired import throughput + without interfering with online serving. The + value must be greater than 0, and less than or + equal to 100. + """ + class FeatureSpec(proto.Message): + r"""Defines the Feature value(s) to import. + + Attributes: + id (str): + Required. ID of the Feature to import values + of. This Feature must exist in the target + EntityType, or the request will fail. + source_field (str): + Source column to get the Feature values from. + If not set, uses the column with the same name + as the Feature ID. + """ + + id = proto.Field(proto.STRING, number=1) + + source_field = proto.Field(proto.STRING, number=2) + + avro_source = proto.Field(proto.MESSAGE, number=2, oneof='source', + message=io.AvroSource, + ) + + bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', + message=io.BigQuerySource, + ) + + csv_source = proto.Field(proto.MESSAGE, number=4, oneof='source', + message=io.CsvSource, + ) + + feature_time_field = proto.Field(proto.STRING, number=6, oneof='feature_time_source') + + feature_time = proto.Field(proto.MESSAGE, number=7, oneof='feature_time_source', + message=timestamp.Timestamp, + ) + + entity_type = proto.Field(proto.STRING, number=1) + + entity_id_field = proto.Field(proto.STRING, number=5) + + feature_specs = proto.RepeatedField(proto.MESSAGE, number=8, + message=FeatureSpec, + ) + + disable_online_serving = proto.Field(proto.BOOL, number=9) + + worker_count = proto.Field(proto.INT32, number=11) + + +class ImportFeatureValuesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.ImportFeatureValues``. + + Attributes: + imported_entity_count (int): + Number of entities that have been imported by + the operation. + imported_feature_value_count (int): + Number of Feature values that have been + imported by the operation. + """ + + imported_entity_count = proto.Field(proto.INT64, number=1) + + imported_feature_value_count = proto.Field(proto.INT64, number=2) + + +class BatchReadFeatureValuesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.BatchReadFeatureValues``. + + Attributes: + csv_read_instances (google.cloud.aiplatform_v1beta1.types.CsvSource): + Each read instance consists of exactly one read timestamp + and one or more entity IDs identifying entities of the + corresponding EntityTypes whose Features are requested. + + Each output instance contains Feature values of requested + entities concatenated together as of the read time. + + An example read instance may be + ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z``. + + An example output instance may be + ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z, foo_entity_feature1_value, bar_entity_feature2_value``. + + Timestamp in each read instance must be millisecond-aligned. + + ``csv_read_instances`` are read instances stored in a + plain-text CSV file. The header should be: + [ENTITY_TYPE_ID1], [ENTITY_TYPE_ID2], ..., timestamp + + The columns can be in any order. + + Values in the timestamp column must use the RFC 3339 format, + e.g. ``2012-07-30T10:43:17.123Z``. + featurestore (str): + Required. The resource name of the Featurestore from which + to query Feature values. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + destination (google.cloud.aiplatform_v1beta1.types.FeatureValueDestination): + Required. Specifies output location and + format. + entity_type_specs (Sequence[google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesRequest.EntityTypeSpec]): + Required. Specifies EntityType grouping Features to read + values of and settings. Each EntityType referenced in + [BatchReadFeatureValuesRequest.entity_type_specs] must have + a column specifying entity IDs in tha EntityType in + [BatchReadFeatureValuesRequest.request][] . + """ + class EntityTypeSpec(proto.Message): + r"""Selects Features of an EntityType to read values of and + specifies read settings. + + Attributes: + entity_type_id (str): + Required. ID of the EntityType to select Features. The + EntityType id is the + ``entity_type_id`` + specified during EntityType creation. + feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): + Required. Selectors choosing which Feature + values to read from the EntityType. + settings (Sequence[google.cloud.aiplatform_v1beta1.types.DestinationFeatureSetting]): + Per-Feature settings for the batch read. + """ + + entity_type_id = proto.Field(proto.STRING, number=1) + + feature_selector = proto.Field(proto.MESSAGE, number=2, + message=gca_feature_selector.FeatureSelector, + ) + + settings = proto.RepeatedField(proto.MESSAGE, number=3, + message='DestinationFeatureSetting', + ) + + csv_read_instances = proto.Field(proto.MESSAGE, number=3, oneof='read_option', + message=io.CsvSource, + ) + + featurestore = proto.Field(proto.STRING, number=1) + + destination = proto.Field(proto.MESSAGE, number=4, + message='FeatureValueDestination', + ) + + entity_type_specs = proto.RepeatedField(proto.MESSAGE, number=7, + message=EntityTypeSpec, + ) + + +class DestinationFeatureSetting(proto.Message): + r""" + + Attributes: + feature_id (str): + Required. The ID of the Feature to apply the + setting to. + destination_field (str): + Specify the field name in the export + destination. If not specified, Feature ID is + used. + """ + + feature_id = proto.Field(proto.STRING, number=1) + + destination_field = proto.Field(proto.STRING, number=2) + + +class FeatureValueDestination(proto.Message): + r"""A destination location for Feature values and format. + + Attributes: + bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): + Output in BigQuery format. output_uri in + ``FeatureValueDestination.bigquery_destination`` + must refer to a table. + tfrecord_destination (google.cloud.aiplatform_v1beta1.types.TFRecordDestination): + Output in TFRecord format. + + Below are the mapping from Feature value type in + Featurestore to Feature value type in TFRecord: + + :: + + Value type in Featurestore | Value type in TFRecord + DOUBLE, DOUBLE_ARRAY | FLOAT_LIST + INT64, INT64_ARRAY | INT64_LIST + STRING, STRING_ARRAY, BYTES | BYTES_LIST + true -> byte_string("true"), false -> byte_string("false") + BOOL, BOOL_ARRAY (true, false) | BYTES_LIST + csv_destination (google.cloud.aiplatform_v1beta1.types.CsvDestination): + Output in CSV format. Array Feature value + types are not allowed in CSV format. + """ + + bigquery_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', + message=io.BigQueryDestination, + ) + + tfrecord_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', + message=io.TFRecordDestination, + ) + + csv_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', + message=io.CsvDestination, + ) + + +class BatchReadFeatureValuesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.BatchReadFeatureValues``. + """ + + +class CreateEntityTypeRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.CreateEntityType``. + + Attributes: + parent (str): + Required. The resource name of the Featurestore to create + EntityTypes. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + entity_type (google.cloud.aiplatform_v1beta1.types.EntityType): + The EntityType to create. + entity_type_id (str): + Required. The ID to use for the EntityType, which will + become the final component of the EntityType's resource + name. + + This value may be up to 60 characters, and valid characters + are ``[a-z0-9_]``. The first character cannot be a number. + + The value must be unique within a featurestore. + """ + + parent = proto.Field(proto.STRING, number=1) + + entity_type = proto.Field(proto.MESSAGE, number=2, + message=gca_entity_type.EntityType, + ) + + entity_type_id = proto.Field(proto.STRING, number=3) + + +class GetEntityTypeRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.GetEntityType``. + + Attributes: + name (str): + Required. The name of the EntityType resource. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListEntityTypesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.ListEntityTypes``. + + Attributes: + parent (str): + Required. The resource name of the Featurestore to list + EntityTypes. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}`` + filter (str): + Lists the EntityTypes that match the filter expression. The + following filters are supported: + + - ``create_time``: Supports =, !=, <, >, >=, and <= + comparisons. Values must be in RFC 3339 format. + - ``update_time``: Supports =, !=, <, >, >=, and <= + comparisons. Values must be in RFC 3339 format. + - ``labels``: Supports key-value equality as well as key + presence. + + Examples: + + - ``create_time > \"2020-01-31T15:30:00.000000Z\" OR update_time > \"2020-01-31T15:30:00.000000Z\"`` + --> EntityTypes created or updated after + 2020-01-31T15:30:00.000000Z. + - ``labels.active = yes AND labels.env = prod`` --> + EntityTypes having both (active: yes) and (env: prod) + labels. + - ``labels.env: *`` --> Any EntityType which has a label + with 'env' as the key. + page_size (int): + The maximum number of EntityTypes to return. + The service may return fewer than this value. If + unspecified, at most 1000 EntityTypes will be + returned. The maximum value is 1000; any value + greater than 1000 will be coerced to 1000. + page_token (str): + A page token, received from a previous + ``FeaturestoreService.ListEntityTypes`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``FeaturestoreService.ListEntityTypes`` + must match the call that provided the page token. + order_by (str): + A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for + descending. + + Supported fields: + + - ``entity_type_id`` + - ``create_time`` + - ``update_time`` + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, + message=field_mask.FieldMask, + ) + + +class ListEntityTypesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.ListEntityTypes``. + + Attributes: + entity_types (Sequence[google.cloud.aiplatform_v1beta1.types.EntityType]): + The EntityTypes matching the request. + next_page_token (str): + A token, which can be sent as + ``ListEntityTypesRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + entity_types = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_entity_type.EntityType, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateEntityTypeRequest(proto.Message): + r"""Request message for [FeaturestoreService.UpdateEntityTypes][]. + + Attributes: + entity_type (google.cloud.aiplatform_v1beta1.types.EntityType): + Required. The EntityType's ``name`` field is used to + identify the EntityType to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Field mask is used to specify the fields to be overwritten + in the EntityType resource by the update. The fields + specified in the update_mask are relative to the resource, + not the full request. A field will be overwritten if it is + in the mask. If the user does not provide a mask then only + the non-empty fields present in the request will be + overwritten. Set the update_mask to ``*`` to override all + fields. + + Updatable fields: + + - ``description`` + - ``labels`` + """ + + entity_type = proto.Field(proto.MESSAGE, number=1, + message=gca_entity_type.EntityType, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + +class DeleteEntityTypeRequest(proto.Message): + r"""Request message for [FeaturestoreService.DeleteEntityTypes][]. + + Attributes: + name (str): + Required. The name of the EntityType to be deleted. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + force (bool): + If set to true, any Features for this + EntityType will also be deleted. (Otherwise, the + request will only work if the EntityType has no + Features.) + """ + + name = proto.Field(proto.STRING, number=1) + + force = proto.Field(proto.BOOL, number=2) + + +class CreateFeatureRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.CreateFeature``. + + Attributes: + parent (str): + Required. The resource name of the EntityType to create a + Feature. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + feature (google.cloud.aiplatform_v1beta1.types.Feature): + Required. The Feature to create. + feature_id (str): + Required. The ID to use for the Feature, which will become + the final component of the Feature's resource name. + + This value may be up to 60 characters, and valid characters + are ``[a-z0-9_]``. The first character cannot be a number. + + The value must be unique within an entitytype. + """ + + parent = proto.Field(proto.STRING, number=1) + + feature = proto.Field(proto.MESSAGE, number=2, + message=gca_feature.Feature, + ) + + feature_id = proto.Field(proto.STRING, number=3) + + +class BatchCreateFeaturesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.BatchCreateFeatures``. + + Attributes: + parent (str): + Required. The resource name of the EntityType to create the + batch of Features under. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + requests (Sequence[google.cloud.aiplatform_v1beta1.types.CreateFeatureRequest]): + Required. The request message specifying the Features to + create. All Features must be created under the same parent + EntityType. The ``parent`` field in each child request + message can be omitted. If ``parent`` is set in a child + request, then the value must match the ``parent`` value in + this request message. + """ + + parent = proto.Field(proto.STRING, number=1) + + requests = proto.RepeatedField(proto.MESSAGE, number=2, + message='CreateFeatureRequest', + ) + + +class BatchCreateFeaturesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.BatchCreateFeatures``. + + Attributes: + features (Sequence[google.cloud.aiplatform_v1beta1.types.Feature]): + The Features created. + """ + + features = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_feature.Feature, + ) + + +class GetFeatureRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.GetFeature``. + + Attributes: + name (str): + Required. The name of the Feature resource. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListFeaturesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.ListFeatures``. + + Attributes: + parent (str): + Required. The resource name of the Location to list + Features. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + filter (str): + Lists the Features that match the filter expression. The + following filters are supported: + + - ``value_type``: Supports = and != comparisons. + - ``create_time``: Supports =, !=, <, >, >=, and <= + comparisons. Values must be in RFC 3339 format. + - ``update_time``: Supports =, !=, <, >, >=, and <= + comparisons. Values must be in RFC 3339 format. + - ``labels``: Supports key-value equality as well as key + presence. + + Examples: + + - ``value_type = DOUBLE`` --> Features whose type is + DOUBLE. + - ``create_time > \"2020-01-31T15:30:00.000000Z\" OR update_time > \"2020-01-31T15:30:00.000000Z\"`` + --> EntityTypes created or updated after + 2020-01-31T15:30:00.000000Z. + - ``labels.active = yes AND labels.env = prod`` --> + Features having both (active: yes) and (env: prod) + labels. + - ``labels.env: *`` --> Any Feature which has a label with + 'env' as the key. + page_size (int): + The maximum number of Features to return. The + service may return fewer than this value. If + unspecified, at most 1000 Features will be + returned. The maximum value is 1000; any value + greater than 1000 will be coerced to 1000. + page_token (str): + A page token, received from a previous + ``FeaturestoreService.ListFeatures`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``FeaturestoreService.ListFeatures`` + must match the call that provided the page token. + order_by (str): + A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for + descending. Supported fields: + + - ``feature_id`` + - ``value_type`` + - ``create_time`` + - ``update_time`` + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + latest_stats_count (int): + If set, return the most recent + ``ListFeaturesRequest.latest_stats_count`` + of stats for each Feature in response. Valid value is [0, + 10]. If number of stats exists < + ``ListFeaturesRequest.latest_stats_count``, + return all existing stats. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, + message=field_mask.FieldMask, + ) + + latest_stats_count = proto.Field(proto.INT32, number=7) + + +class ListFeaturesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.ListFeatures``. + + Attributes: + features (Sequence[google.cloud.aiplatform_v1beta1.types.Feature]): + The Features matching the request. + next_page_token (str): + A token, which can be sent as + ``ListFeaturesRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + features = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_feature.Feature, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class SearchFeaturesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.SearchFeatures``. + + Attributes: + location (str): + Required. The resource name of the Location to search + Features. Format: + ``projects/{project}/locations/{location}`` + query (str): + Query string that is a conjunction of field-restricted + queries and/or field-restricted filters. Field-restricted + queries and filters can be combined using ``AND`` to form a + conjunction. + + A field query is in the form FIELD:QUERY. This implicitly + checks if QUERY exists as a substring within Feature's + FIELD. The QUERY and the FIELD are converted to a sequence + of words (i.e. tokens) for comparison. This is done by: + + - Removing leading/trailing whitespace and tokenizing the + search value. Characters that are not one of alphanumeric + [a-zA-Z0-9], underscore [_], or asterisk [*] are treated + as delimiters for tokens. (*) is treated as a wildcard + that matches characters within a token. + - Ignoring case. + - Prepending an asterisk to the first and appending an + asterisk to the last token in QUERY. + + A QUERY must be either a singular token or a phrase. A + phrase is one or multiple words enclosed in double quotation + marks ("). With phrases, the order of the words is + important. Words in the phrase must be matching in order and + consecutively. + + Supported FIELDs for field-restricted queries: + + - ``feature_id`` + - ``description`` + - ``entity_type_id`` + + Examples: + + - ``feature_id: foo`` --> Matches a Feature with ID + containing the substring ``foo`` (eg. ``foo``, + ``foofeature``, ``barfoo``). + - ``feature_id: foo*feature`` --> Matches a Feature with ID + containing the substring ``foo*feature`` (eg. + ``foobarfeature``). + - ``feature_id: foo AND description: bar`` --> Matches a + Feature with ID containing the substring ``foo`` and + description containing the substring ``bar``. + + Besides field queries, the following exact-match filters are + supported. The exact-match filters do not support wildcards. + Unlike field-restricted queries, exact-match filters are + case-sensitive. + + - ``feature_id``: Supports = comparisons. + - ``description``: Supports = comparisons. Multi-token + filters should be enclosed in quotes. + - ``entity_type_id``: Supports = comparisons. + - ``value_type``: Supports = and != comparisons. + - ``labels``: Supports key-value equality as well as key + presence. + - ``featurestore_id``: Supports = comparisons. + + Examples: + + - ``description = "foo bar"`` --> Any Feature with + description exactly equal to ``foo bar`` + - ``value_type = DOUBLE`` --> Features whose type is + DOUBLE. + - ``labels.active = yes AND labels.env = prod`` --> + Features having both (active: yes) and (env: prod) + labels. + - ``labels.env: *`` --> Any Feature which has a label with + ``env`` as the key. + page_size (int): + The maximum number of Features to return. The + service may return fewer than this value. If + unspecified, at most 100 Features will be + returned. The maximum value is 100; any value + greater than 100 will be coerced to 100. + page_token (str): + A page token, received from a previous + ``FeaturestoreService.SearchFeatures`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``FeaturestoreService.SearchFeatures``, + except ``page_size``, must match the call that provided the + page token. + """ + + location = proto.Field(proto.STRING, number=1) + + query = proto.Field(proto.STRING, number=3) + + page_size = proto.Field(proto.INT32, number=4) + + page_token = proto.Field(proto.STRING, number=5) + + +class SearchFeaturesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.SearchFeatures``. + + Attributes: + features (Sequence[google.cloud.aiplatform_v1beta1.types.Feature]): + The Features matching the request. + + Fields returned: + + - ``name`` + - ``description`` + - ``labels`` + - ``create_time`` + - ``update_time`` + next_page_token (str): + A token, which can be sent as + ``SearchFeaturesRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + features = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_feature.Feature, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateFeatureRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.UpdateFeature``. + + Attributes: + feature (google.cloud.aiplatform_v1beta1.types.Feature): + Required. The Feature's ``name`` field is used to identify + the Feature to be updated. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Field mask is used to specify the fields to be overwritten + in the Features resource by the update. The fields specified + in the update_mask are relative to the resource, not the + full request. A field will be overwritten if it is in the + mask. If the user does not provide a mask then only the + non-empty fields present in the request will be overwritten. + Set the update_mask to ``*`` to override all fields. + + Updatable fields: + + - ``description`` + - ``labels`` + - ``monitoring_config.snapshot_analysis.disabled`` + - ``monitoring_config.snapshot_analysis.monitoring_interval`` + """ + + feature = proto.Field(proto.MESSAGE, number=1, + message=gca_feature.Feature, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + +class DeleteFeatureRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.DeleteFeature``. + + Attributes: + name (str): + Required. The name of the Features to be deleted. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateFeaturestoreOperationMetadata(proto.Message): + r"""Details of operations that perform create Featurestore. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Featurestore. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class UpdateFeaturestoreOperationMetadata(proto.Message): + r"""Details of operations that perform update Featurestore. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Featurestore. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class ImportFeatureValuesOperationMetadata(proto.Message): + r"""Details of operations that perform import feature values. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Featurestore import + feature values. + imported_entity_count (int): + Number of entities that have been imported by + the operation. + imported_feature_value_count (int): + Number of feature values that have been + imported by the operation. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + imported_entity_count = proto.Field(proto.INT64, number=2) + + imported_feature_value_count = proto.Field(proto.INT64, number=3) + + +class BatchReadFeatureValuesOperationMetadata(proto.Message): + r"""Details of operations that batch reads Feature values. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Featurestore batch + read Features values. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class CreateEntityTypeOperationMetadata(proto.Message): + r"""Details of operations that perform create EntityType. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for EntityType. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class CreateFeatureOperationMetadata(proto.Message): + r"""Details of operations that perform create Feature. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Feature. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class BatchCreateFeaturesOperationMetadata(proto.Message): + r"""Details of operations that perform batch create Features. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Feature. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index 0d938b4628..3a3abf1b06 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -21,15 +21,45 @@ __protobuf__ = proto.module( package='google.cloud.aiplatform.v1beta1', manifest={ + 'AvroSource', + 'CsvSource', 'GcsSource', 'GcsDestination', 'BigQuerySource', 'BigQueryDestination', + 'CsvDestination', + 'TFRecordDestination', 'ContainerRegistryDestination', }, ) +class AvroSource(proto.Message): + r"""The storage details for Avro input content. + + Attributes: + gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): + Required. Google Cloud Storage location. + """ + + gcs_source = proto.Field(proto.MESSAGE, number=1, + message='GcsSource', + ) + + +class CsvSource(proto.Message): + r"""The storage details for CSV input content. + + Attributes: + gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): + Required. Google Cloud Storage location. + """ + + gcs_source = proto.Field(proto.MESSAGE, number=1, + message='GcsSource', + ) + + class GcsSource(proto.Message): r"""The Google Cloud Storage location for the input content. @@ -95,6 +125,32 @@ class BigQueryDestination(proto.Message): output_uri = proto.Field(proto.STRING, number=1) +class CsvDestination(proto.Message): + r"""The storage details for CSV output content. + + Attributes: + gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): + Google Cloud Storage location. + """ + + gcs_destination = proto.Field(proto.MESSAGE, number=1, + message='GcsDestination', + ) + + +class TFRecordDestination(proto.Message): + r"""The storage details for TFRecord output content. + + Attributes: + gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): + Google Cloud Storage location. + """ + + gcs_destination = proto.Field(proto.MESSAGE, number=1, + message='GcsDestination', + ) + + class ContainerRegistryDestination(proto.Message): r"""The Container Registry location for the container image. diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index aaa87f85bb..7a3f0e4567 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -428,8 +428,9 @@ class PredictSchemata(proto.Message): class ModelContainerSpec(proto.Message): - r"""Specification of a container for serving predictions. This message - is a subset of the Kubernetes Container v1 core + r"""Specification of a container for serving predictions. Some fields in + this message correspond to fields in the Kubernetes Container v1 + core `specification `__. Attributes: diff --git a/google/cloud/aiplatform_v1beta1/types/types.py b/google/cloud/aiplatform_v1beta1/types/types.py new file mode 100644 index 0000000000..c2803a3c3a --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/types.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'BoolArray', + 'DoubleArray', + 'Int64Array', + 'StringArray', + }, +) + + +class BoolArray(proto.Message): + r"""Bool list type feature value. + + Attributes: + values (Sequence[bool]): + A list of bool values. + """ + + values = proto.RepeatedField(proto.BOOL, number=1) + + +class DoubleArray(proto.Message): + r"""Double list type feature value. + + Attributes: + values (Sequence[float]): + A list of bool values. + """ + + values = proto.RepeatedField(proto.DOUBLE, number=1) + + +class Int64Array(proto.Message): + r"""Int64 list type feature value. + + Attributes: + values (Sequence[int]): + A list of int64 values. + """ + + values = proto.RepeatedField(proto.INT64, number=1) + + +class StringArray(proto.Message): + r"""A list of string values. + + Attributes: + values (Sequence[str]): + A list of string values. + """ + + values = proto.RepeatedField(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/noxfile.py b/noxfile.py index 32bd822f2b..ab5a7296ad 100644 --- a/noxfile.py +++ b/noxfile.py @@ -214,9 +214,7 @@ def docfx(session): """Build the docfx yaml files for this library.""" session.install("-e", ".") - # sphinx-docfx-yaml supports up to sphinx version 1.5.5. - # https://github.com/docascode/sphinx-docfx-yaml/issues/97 - session.install("sphinx==1.5.5", "alabaster", "recommonmark", "sphinx-docfx-yaml") + session.install("sphinx", "alabaster", "recommonmark", "gcp-sphinx-docfx-yaml") shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( diff --git a/renovate.json b/renovate.json index f08bc22c9a..c04895563e 100644 --- a/renovate.json +++ b/renovate.json @@ -2,5 +2,8 @@ "extends": [ "config:base", ":preserveSemverRanges" ], - "ignorePaths": [".pre-commit-config.yaml"] + "ignorePaths": [".pre-commit-config.yaml"], + "pip_requirements": { + "fileMatch": ["requirements-test.txt", "samples/[\\S/]*constraints.txt", "samples/[\\S/]*constraints-test.txt"] + } } diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 2f1c62f3ef..ecf6f7286a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -1443,17 +1443,19 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - dataset = "clam" + location = "clam" + dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", } path = MigrationServiceClient.dataset_path(**expected) @@ -1463,19 +1465,17 @@ def test_parse_dataset_path(): assert expected == actual def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" + project = "cuttlefish" + dataset = "mussel" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", + "project": "winkle", "dataset": "nautilus", } diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py new file mode 100644 index 0000000000..3c99da7fac --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -0,0 +1,1292 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceClient +from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import transports +from google.cloud.aiplatform_v1beta1.types import feature_selector +from google.cloud.aiplatform_v1beta1.types import featurestore_online_service +from google.oauth2 import service_account +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(None) is None + assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ + FeaturestoreOnlineServingServiceClient, + FeaturestoreOnlineServingServiceAsyncClient, +]) +def test_featurestore_online_serving_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +@pytest.mark.parametrize("client_class", [ + FeaturestoreOnlineServingServiceClient, + FeaturestoreOnlineServingServiceAsyncClient, +]) +def test_featurestore_online_serving_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_featurestore_online_serving_service_client_get_transport_class(): + transport = FeaturestoreOnlineServingServiceClient.get_transport_class() + available_transports = [ + transports.FeaturestoreOnlineServingServiceGrpcTransport, + ] + assert transport in available_transports + + transport = FeaturestoreOnlineServingServiceClient.get_transport_class("grpc") + assert transport == transports.FeaturestoreOnlineServingServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), + (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(FeaturestoreOnlineServingServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceClient)) +@mock.patch.object(FeaturestoreOnlineServingServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient)) +def test_featurestore_online_serving_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(FeaturestoreOnlineServingServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(FeaturestoreOnlineServingServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + + (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc", "true"), + (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc", "false"), + (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(FeaturestoreOnlineServingServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceClient)) +@mock.patch.object(FeaturestoreOnlineServingServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_featurestore_online_serving_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), + (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_featurestore_online_serving_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), + (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_featurestore_online_serving_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_featurestore_online_serving_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = FeaturestoreOnlineServingServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_read_feature_values(transport: str = 'grpc', request_type=featurestore_online_service.ReadFeatureValuesRequest): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_online_service.ReadFeatureValuesResponse( + ) + + response = client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_online_service.ReadFeatureValuesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) + + +def test_read_feature_values_from_dict(): + test_read_feature_values(request_type=dict) + + +def test_read_feature_values_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + client.read_feature_values() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_online_service.ReadFeatureValuesRequest() + +@pytest.mark.asyncio +async def test_read_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_online_service.ReadFeatureValuesRequest): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_online_service.ReadFeatureValuesResponse( + )) + + response = await client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_online_service.ReadFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) + + +@pytest.mark.asyncio +async def test_read_feature_values_async_from_dict(): + await test_read_feature_values_async(request_type=dict) + + +def test_read_feature_values_field_headers(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_online_service.ReadFeatureValuesRequest() + request.entity_type = 'entity_type/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + call.return_value = featurestore_online_service.ReadFeatureValuesResponse() + + client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type=entity_type/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_read_feature_values_field_headers_async(): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_online_service.ReadFeatureValuesRequest() + request.entity_type = 'entity_type/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_online_service.ReadFeatureValuesResponse()) + + await client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type=entity_type/value', + ) in kw['metadata'] + + +def test_read_feature_values_flattened(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_online_service.ReadFeatureValuesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.read_feature_values( + entity_type='entity_type_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == 'entity_type_value' + + +def test_read_feature_values_flattened_error(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.read_feature_values( + featurestore_online_service.ReadFeatureValuesRequest(), + entity_type='entity_type_value', + ) + + +@pytest.mark.asyncio +async def test_read_feature_values_flattened_async(): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_online_service.ReadFeatureValuesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_online_service.ReadFeatureValuesResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.read_feature_values( + entity_type='entity_type_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == 'entity_type_value' + + +@pytest.mark.asyncio +async def test_read_feature_values_flattened_error_async(): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.read_feature_values( + featurestore_online_service.ReadFeatureValuesRequest(), + entity_type='entity_type_value', + ) + + +def test_streaming_read_feature_values(transport: str = 'grpc', request_type=featurestore_online_service.StreamingReadFeatureValuesRequest): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + + response = client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, featurestore_online_service.ReadFeatureValuesResponse) + + +def test_streaming_read_feature_values_from_dict(): + test_streaming_read_feature_values(request_type=dict) + + +def test_streaming_read_feature_values_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + client.streaming_read_feature_values() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + +@pytest.mark.asyncio +async def test_streaming_read_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_online_service.StreamingReadFeatureValuesRequest): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock(side_effect=[featurestore_online_service.ReadFeatureValuesResponse()]) + + response = await client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, featurestore_online_service.ReadFeatureValuesResponse) + + +@pytest.mark.asyncio +async def test_streaming_read_feature_values_async_from_dict(): + await test_streaming_read_feature_values_async(request_type=dict) + + +def test_streaming_read_feature_values_field_headers(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_online_service.StreamingReadFeatureValuesRequest() + request.entity_type = 'entity_type/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + + client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type=entity_type/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_streaming_read_feature_values_field_headers_async(): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_online_service.StreamingReadFeatureValuesRequest() + request.entity_type = 'entity_type/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock(side_effect=[featurestore_online_service.ReadFeatureValuesResponse()]) + + await client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type=entity_type/value', + ) in kw['metadata'] + + +def test_streaming_read_feature_values_flattened(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.streaming_read_feature_values( + entity_type='entity_type_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == 'entity_type_value' + + +def test_streaming_read_feature_values_flattened_error(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.streaming_read_feature_values( + featurestore_online_service.StreamingReadFeatureValuesRequest(), + entity_type='entity_type_value', + ) + + +@pytest.mark.asyncio +async def test_streaming_read_feature_values_flattened_async(): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.streaming_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.streaming_read_feature_values( + entity_type='entity_type_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == 'entity_type_value' + + +@pytest.mark.asyncio +async def test_streaming_read_feature_values_flattened_error_async(): + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.streaming_read_feature_values( + featurestore_online_service.StreamingReadFeatureValuesRequest(), + entity_type='entity_type_value', + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FeaturestoreOnlineServingServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = FeaturestoreOnlineServingServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.FeaturestoreOnlineServingServiceGrpcTransport, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + ) + + +def test_featurestore_online_serving_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.FeaturestoreOnlineServingServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_featurestore_online_serving_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.FeaturestoreOnlineServingServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'read_feature_values', + 'streaming_read_feature_values', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + +def test_featurestore_online_serving_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.FeaturestoreOnlineServingServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_featurestore_online_serving_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.FeaturestoreOnlineServingServiceTransport() + adc.assert_called_once() + + +def test_featurestore_online_serving_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + FeaturestoreOnlineServingServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_featurestore_online_serving_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.FeaturestoreOnlineServingServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize("transport_class", [transports.FeaturestoreOnlineServingServiceGrpcTransport, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport]) +def test_featurestore_online_serving_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + + +def test_featurestore_online_serving_service_host_no_port(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_featurestore_online_serving_service_host_with_port(): + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_featurestore_online_serving_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_featurestore_online_serving_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.FeaturestoreOnlineServingServiceGrpcTransport, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport]) +def test_featurestore_online_serving_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.FeaturestoreOnlineServingServiceGrpcTransport, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport]) +def test_featurestore_online_serving_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_entity_type_path(): + project = "squid" + location = "clam" + featurestore = "whelk" + entity_type = "octopus" + + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) + actual = FeaturestoreOnlineServingServiceClient.entity_type_path(project, location, featurestore, entity_type) + assert expected == actual + + +def test_parse_entity_type_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "featurestore": "cuttlefish", + "entity_type": "mussel", + + } + path = FeaturestoreOnlineServingServiceClient.entity_type_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreOnlineServingServiceClient.parse_entity_type_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "winkle" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = FeaturestoreOnlineServingServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nautilus", + + } + path = FeaturestoreOnlineServingServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "scallop" + + expected = "folders/{folder}".format(folder=folder, ) + actual = FeaturestoreOnlineServingServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "abalone", + + } + path = FeaturestoreOnlineServingServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreOnlineServingServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "squid" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = FeaturestoreOnlineServingServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "clam", + + } + path = FeaturestoreOnlineServingServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreOnlineServingServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "whelk" + + expected = "projects/{project}".format(project=project, ) + actual = FeaturestoreOnlineServingServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "octopus", + + } + path = FeaturestoreOnlineServingServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreOnlineServingServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "oyster" + location = "nudibranch" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = FeaturestoreOnlineServingServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + + } + path = FeaturestoreOnlineServingServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreOnlineServingServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.FeaturestoreOnlineServingServiceTransport, '_prep_wrapped_messages') as prep: + client = FeaturestoreOnlineServingServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.FeaturestoreOnlineServingServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = FeaturestoreOnlineServingServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py new file mode 100644 index 0000000000..74c27eb5a3 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -0,0 +1,6377 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.featurestore_service import FeaturestoreServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.featurestore_service import FeaturestoreServiceClient +from google.cloud.aiplatform_v1beta1.services.featurestore_service import pagers +from google.cloud.aiplatform_v1beta1.services.featurestore_service import transports +from google.cloud.aiplatform_v1beta1.types import entity_type +from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type +from google.cloud.aiplatform_v1beta1.types import feature +from google.cloud.aiplatform_v1beta1.types import feature as gca_feature +from google.cloud.aiplatform_v1beta1.types import feature_monitoring_stats +from google.cloud.aiplatform_v1beta1.types import feature_selector +from google.cloud.aiplatform_v1beta1.types import featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore as gca_featurestore +from google.cloud.aiplatform_v1beta1.types import featurestore_monitoring +from google.cloud.aiplatform_v1beta1.types import featurestore_service +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert FeaturestoreServiceClient._get_default_mtls_endpoint(None) is None + assert FeaturestoreServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert FeaturestoreServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert FeaturestoreServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert FeaturestoreServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert FeaturestoreServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ + FeaturestoreServiceClient, + FeaturestoreServiceAsyncClient, +]) +def test_featurestore_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +@pytest.mark.parametrize("client_class", [ + FeaturestoreServiceClient, + FeaturestoreServiceAsyncClient, +]) +def test_featurestore_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_featurestore_service_client_get_transport_class(): + transport = FeaturestoreServiceClient.get_transport_class() + available_transports = [ + transports.FeaturestoreServiceGrpcTransport, + ] + assert transport in available_transports + + transport = FeaturestoreServiceClient.get_transport_class("grpc") + assert transport == transports.FeaturestoreServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc"), + (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(FeaturestoreServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceClient)) +@mock.patch.object(FeaturestoreServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceAsyncClient)) +def test_featurestore_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(FeaturestoreServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(FeaturestoreServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc", "true"), + (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc", "false"), + (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(FeaturestoreServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceClient)) +@mock.patch.object(FeaturestoreServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_featurestore_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc"), + (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_featurestore_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc"), + (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_featurestore_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_featurestore_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = FeaturestoreServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_featurestore(transport: str = 'grpc', request_type=featurestore_service.CreateFeaturestoreRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_featurestore_from_dict(): + test_create_featurestore(request_type=dict) + + +def test_create_featurestore_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + client.create_featurestore() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateFeaturestoreRequest() + +@pytest.mark.asyncio +async def test_create_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.CreateFeaturestoreRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_featurestore_async_from_dict(): + await test_create_featurestore_async(request_type=dict) + + +def test_create_featurestore_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.CreateFeaturestoreRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_featurestore_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.CreateFeaturestoreRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_featurestore_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_featurestore( + parent='parent_value', + featurestore=gca_featurestore.Featurestore(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + + +def test_create_featurestore_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_featurestore( + featurestore_service.CreateFeaturestoreRequest(), + parent='parent_value', + featurestore=gca_featurestore.Featurestore(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_featurestore_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_featurestore( + parent='parent_value', + featurestore=gca_featurestore.Featurestore(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + + +@pytest.mark.asyncio +async def test_create_featurestore_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_featurestore( + featurestore_service.CreateFeaturestoreRequest(), + parent='parent_value', + featurestore=gca_featurestore.Featurestore(name='name_value'), + ) + + +def test_get_featurestore(transport: str = 'grpc', request_type=featurestore_service.GetFeaturestoreRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore.Featurestore( + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + + state=featurestore.Featurestore.State.STABLE, + + ) + + response = client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetFeaturestoreRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, featurestore.Featurestore) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.state == featurestore.Featurestore.State.STABLE + + +def test_get_featurestore_from_dict(): + test_get_featurestore(request_type=dict) + + +def test_get_featurestore_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + client.get_featurestore() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetFeaturestoreRequest() + +@pytest.mark.asyncio +async def test_get_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.GetFeaturestoreRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore.Featurestore( + name='name_value', + display_name='display_name_value', + etag='etag_value', + state=featurestore.Featurestore.State.STABLE, + )) + + response = await client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, featurestore.Featurestore) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + assert response.state == featurestore.Featurestore.State.STABLE + + +@pytest.mark.asyncio +async def test_get_featurestore_async_from_dict(): + await test_get_featurestore_async(request_type=dict) + + +def test_get_featurestore_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.GetFeaturestoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + call.return_value = featurestore.Featurestore() + + client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_featurestore_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.GetFeaturestoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore.Featurestore()) + + await client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_featurestore_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore.Featurestore() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_featurestore( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_featurestore_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_featurestore( + featurestore_service.GetFeaturestoreRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_featurestore_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore.Featurestore() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore.Featurestore()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_featurestore( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_featurestore_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_featurestore( + featurestore_service.GetFeaturestoreRequest(), + name='name_value', + ) + + +def test_list_featurestores(transport: str = 'grpc', request_type=featurestore_service.ListFeaturestoresRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListFeaturestoresResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListFeaturestoresRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListFeaturestoresPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_featurestores_from_dict(): + test_list_featurestores(request_type=dict) + + +def test_list_featurestores_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + client.list_featurestores() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListFeaturestoresRequest() + +@pytest.mark.asyncio +async def test_list_featurestores_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ListFeaturestoresRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturestoresResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListFeaturestoresRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListFeaturestoresAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_featurestores_async_from_dict(): + await test_list_featurestores_async(request_type=dict) + + +def test_list_featurestores_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ListFeaturestoresRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + call.return_value = featurestore_service.ListFeaturestoresResponse() + + client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_featurestores_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ListFeaturestoresRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturestoresResponse()) + + await client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_featurestores_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListFeaturestoresResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_featurestores( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_featurestores_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_featurestores( + featurestore_service.ListFeaturestoresRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_featurestores_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListFeaturestoresResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturestoresResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_featurestores( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_featurestores_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_featurestores( + featurestore_service.ListFeaturestoresRequest(), + parent='parent_value', + ) + + +def test_list_featurestores_pager(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[], + next_page_token='def', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_featurestores(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, featurestore.Featurestore) + for i in results) + +def test_list_featurestores_pages(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[], + next_page_token='def', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + ), + RuntimeError, + ) + pages = list(client.list_featurestores(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_featurestores_async_pager(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[], + next_page_token='def', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_featurestores(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, featurestore.Featurestore) + for i in responses) + +@pytest.mark.asyncio +async def test_list_featurestores_async_pages(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_featurestores), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[], + next_page_token='def', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturestoresResponse( + featurestores=[ + featurestore.Featurestore(), + featurestore.Featurestore(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_featurestores(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_featurestore(transport: str = 'grpc', request_type=featurestore_service.UpdateFeaturestoreRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_featurestore_from_dict(): + test_update_featurestore(request_type=dict) + + +def test_update_featurestore_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + client.update_featurestore() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateFeaturestoreRequest() + +@pytest.mark.asyncio +async def test_update_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.UpdateFeaturestoreRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_featurestore_async_from_dict(): + await test_update_featurestore_async(request_type=dict) + + +def test_update_featurestore_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.UpdateFeaturestoreRequest() + request.featurestore.name = 'featurestore.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'featurestore.name=featurestore.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_featurestore_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.UpdateFeaturestoreRequest() + request.featurestore.name = 'featurestore.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'featurestore.name=featurestore.name/value', + ) in kw['metadata'] + + +def test_update_featurestore_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_featurestore( + featurestore=gca_featurestore.Featurestore(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_featurestore_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_featurestore( + featurestore_service.UpdateFeaturestoreRequest(), + featurestore=gca_featurestore.Featurestore(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_featurestore_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_featurestore( + featurestore=gca_featurestore.Featurestore(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_featurestore_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_featurestore( + featurestore_service.UpdateFeaturestoreRequest(), + featurestore=gca_featurestore.Featurestore(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_featurestore(transport: str = 'grpc', request_type=featurestore_service.DeleteFeaturestoreRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_featurestore_from_dict(): + test_delete_featurestore(request_type=dict) + + +def test_delete_featurestore_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + client.delete_featurestore() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteFeaturestoreRequest() + +@pytest.mark.asyncio +async def test_delete_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.DeleteFeaturestoreRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteFeaturestoreRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_featurestore_async_from_dict(): + await test_delete_featurestore_async(request_type=dict) + + +def test_delete_featurestore_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.DeleteFeaturestoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_featurestore_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.DeleteFeaturestoreRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_featurestore_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_featurestore( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_featurestore_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_featurestore( + featurestore_service.DeleteFeaturestoreRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_featurestore_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_featurestore), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_featurestore( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_featurestore_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_featurestore( + featurestore_service.DeleteFeaturestoreRequest(), + name='name_value', + ) + + +def test_create_entity_type(transport: str = 'grpc', request_type=featurestore_service.CreateEntityTypeRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateEntityTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_entity_type_from_dict(): + test_create_entity_type(request_type=dict) + + +def test_create_entity_type_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + client.create_entity_type() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateEntityTypeRequest() + +@pytest.mark.asyncio +async def test_create_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.CreateEntityTypeRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateEntityTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_entity_type_async_from_dict(): + await test_create_entity_type_async(request_type=dict) + + +def test_create_entity_type_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.CreateEntityTypeRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_entity_type_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.CreateEntityTypeRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_entity_type_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_entity_type( + parent='parent_value', + entity_type=gca_entity_type.EntityType(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + + +def test_create_entity_type_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_entity_type( + featurestore_service.CreateEntityTypeRequest(), + parent='parent_value', + entity_type=gca_entity_type.EntityType(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_entity_type_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_entity_type( + parent='parent_value', + entity_type=gca_entity_type.EntityType(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + + +@pytest.mark.asyncio +async def test_create_entity_type_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_entity_type( + featurestore_service.CreateEntityTypeRequest(), + parent='parent_value', + entity_type=gca_entity_type.EntityType(name='name_value'), + ) + + +def test_get_entity_type(transport: str = 'grpc', request_type=featurestore_service.GetEntityTypeRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = entity_type.EntityType( + name='name_value', + + description='description_value', + + etag='etag_value', + + ) + + response = client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetEntityTypeRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, entity_type.EntityType) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +def test_get_entity_type_from_dict(): + test_get_entity_type(request_type=dict) + + +def test_get_entity_type_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + client.get_entity_type() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetEntityTypeRequest() + +@pytest.mark.asyncio +async def test_get_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.GetEntityTypeRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(entity_type.EntityType( + name='name_value', + description='description_value', + etag='etag_value', + )) + + response = await client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetEntityTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, entity_type.EntityType) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_get_entity_type_async_from_dict(): + await test_get_entity_type_async(request_type=dict) + + +def test_get_entity_type_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.GetEntityTypeRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + call.return_value = entity_type.EntityType() + + client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_entity_type_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.GetEntityTypeRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(entity_type.EntityType()) + + await client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_entity_type_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = entity_type.EntityType() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_entity_type( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_entity_type_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_entity_type( + featurestore_service.GetEntityTypeRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_entity_type_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = entity_type.EntityType() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(entity_type.EntityType()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_entity_type( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_entity_type_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_entity_type( + featurestore_service.GetEntityTypeRequest(), + name='name_value', + ) + + +def test_list_entity_types(transport: str = 'grpc', request_type=featurestore_service.ListEntityTypesRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListEntityTypesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListEntityTypesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListEntityTypesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_entity_types_from_dict(): + test_list_entity_types(request_type=dict) + + +def test_list_entity_types_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + client.list_entity_types() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListEntityTypesRequest() + +@pytest.mark.asyncio +async def test_list_entity_types_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ListEntityTypesRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListEntityTypesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListEntityTypesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEntityTypesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_entity_types_async_from_dict(): + await test_list_entity_types_async(request_type=dict) + + +def test_list_entity_types_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ListEntityTypesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + call.return_value = featurestore_service.ListEntityTypesResponse() + + client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_entity_types_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ListEntityTypesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListEntityTypesResponse()) + + await client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_entity_types_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListEntityTypesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_entity_types( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_entity_types_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_entity_types( + featurestore_service.ListEntityTypesRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_entity_types_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListEntityTypesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListEntityTypesResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_entity_types( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_entity_types_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_entity_types( + featurestore_service.ListEntityTypesRequest(), + parent='parent_value', + ) + + +def test_list_entity_types_pager(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + entity_type.EntityType(), + ], + next_page_token='abc', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[], + next_page_token='def', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + ], + next_page_token='ghi', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_entity_types(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, entity_type.EntityType) + for i in results) + +def test_list_entity_types_pages(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + entity_type.EntityType(), + ], + next_page_token='abc', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[], + next_page_token='def', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + ], + next_page_token='ghi', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + ], + ), + RuntimeError, + ) + pages = list(client.list_entity_types(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_entity_types_async_pager(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + entity_type.EntityType(), + ], + next_page_token='abc', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[], + next_page_token='def', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + ], + next_page_token='ghi', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_entity_types(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, entity_type.EntityType) + for i in responses) + +@pytest.mark.asyncio +async def test_list_entity_types_async_pages(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_entity_types), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + entity_type.EntityType(), + ], + next_page_token='abc', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[], + next_page_token='def', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + ], + next_page_token='ghi', + ), + featurestore_service.ListEntityTypesResponse( + entity_types=[ + entity_type.EntityType(), + entity_type.EntityType(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_entity_types(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_entity_type(transport: str = 'grpc', request_type=featurestore_service.UpdateEntityTypeRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_entity_type.EntityType( + name='name_value', + + description='description_value', + + etag='etag_value', + + ) + + response = client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateEntityTypeRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_entity_type.EntityType) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +def test_update_entity_type_from_dict(): + test_update_entity_type(request_type=dict) + + +def test_update_entity_type_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + client.update_entity_type() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateEntityTypeRequest() + +@pytest.mark.asyncio +async def test_update_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.UpdateEntityTypeRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_entity_type.EntityType( + name='name_value', + description='description_value', + etag='etag_value', + )) + + response = await client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateEntityTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_entity_type.EntityType) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_update_entity_type_async_from_dict(): + await test_update_entity_type_async(request_type=dict) + + +def test_update_entity_type_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.UpdateEntityTypeRequest() + request.entity_type.name = 'entity_type.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + call.return_value = gca_entity_type.EntityType() + + client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type.name=entity_type.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_entity_type_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.UpdateEntityTypeRequest() + request.entity_type.name = 'entity_type.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_entity_type.EntityType()) + + await client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type.name=entity_type.name/value', + ) in kw['metadata'] + + +def test_update_entity_type_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_entity_type.EntityType() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_entity_type( + entity_type=gca_entity_type.EntityType(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_entity_type_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_entity_type( + featurestore_service.UpdateEntityTypeRequest(), + entity_type=gca_entity_type.EntityType(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_entity_type_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_entity_type.EntityType() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_entity_type.EntityType()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_entity_type( + entity_type=gca_entity_type.EntityType(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_entity_type_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_entity_type( + featurestore_service.UpdateEntityTypeRequest(), + entity_type=gca_entity_type.EntityType(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_entity_type(transport: str = 'grpc', request_type=featurestore_service.DeleteEntityTypeRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteEntityTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_entity_type_from_dict(): + test_delete_entity_type(request_type=dict) + + +def test_delete_entity_type_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + client.delete_entity_type() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteEntityTypeRequest() + +@pytest.mark.asyncio +async def test_delete_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.DeleteEntityTypeRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteEntityTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_entity_type_async_from_dict(): + await test_delete_entity_type_async(request_type=dict) + + +def test_delete_entity_type_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.DeleteEntityTypeRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_entity_type_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.DeleteEntityTypeRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_entity_type_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_entity_type( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_entity_type_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_entity_type( + featurestore_service.DeleteEntityTypeRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_entity_type_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_entity_type), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_entity_type( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_entity_type_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_entity_type( + featurestore_service.DeleteEntityTypeRequest(), + name='name_value', + ) + + +def test_create_feature(transport: str = 'grpc', request_type=featurestore_service.CreateFeatureRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateFeatureRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_feature_from_dict(): + test_create_feature(request_type=dict) + + +def test_create_feature_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + client.create_feature() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateFeatureRequest() + +@pytest.mark.asyncio +async def test_create_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.CreateFeatureRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.CreateFeatureRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_feature_async_from_dict(): + await test_create_feature_async(request_type=dict) + + +def test_create_feature_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.CreateFeatureRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_feature_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.CreateFeatureRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_feature_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_feature( + parent='parent_value', + feature=gca_feature.Feature(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].feature == gca_feature.Feature(name='name_value') + + +def test_create_feature_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_feature( + featurestore_service.CreateFeatureRequest(), + parent='parent_value', + feature=gca_feature.Feature(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_feature_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_feature( + parent='parent_value', + feature=gca_feature.Feature(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].feature == gca_feature.Feature(name='name_value') + + +@pytest.mark.asyncio +async def test_create_feature_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_feature( + featurestore_service.CreateFeatureRequest(), + parent='parent_value', + feature=gca_feature.Feature(name='name_value'), + ) + + +def test_batch_create_features(transport: str = 'grpc', request_type=featurestore_service.BatchCreateFeaturesRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.BatchCreateFeaturesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_batch_create_features_from_dict(): + test_batch_create_features(request_type=dict) + + +def test_batch_create_features_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + client.batch_create_features() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.BatchCreateFeaturesRequest() + +@pytest.mark.asyncio +async def test_batch_create_features_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.BatchCreateFeaturesRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.BatchCreateFeaturesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_batch_create_features_async_from_dict(): + await test_batch_create_features_async(request_type=dict) + + +def test_batch_create_features_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.BatchCreateFeaturesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_batch_create_features_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.BatchCreateFeaturesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_batch_create_features_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_create_features( + parent='parent_value', + requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].requests == [featurestore_service.CreateFeatureRequest(parent='parent_value')] + + +def test_batch_create_features_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_create_features( + featurestore_service.BatchCreateFeaturesRequest(), + parent='parent_value', + requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + ) + + +@pytest.mark.asyncio +async def test_batch_create_features_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_create_features( + parent='parent_value', + requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].requests == [featurestore_service.CreateFeatureRequest(parent='parent_value')] + + +@pytest.mark.asyncio +async def test_batch_create_features_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_create_features( + featurestore_service.BatchCreateFeaturesRequest(), + parent='parent_value', + requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + ) + + +def test_get_feature(transport: str = 'grpc', request_type=featurestore_service.GetFeatureRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = feature.Feature( + name='name_value', + + description='description_value', + + value_type=feature.Feature.ValueType.BOOL, + + etag='etag_value', + + ) + + response = client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetFeatureRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, feature.Feature) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.value_type == feature.Feature.ValueType.BOOL + + assert response.etag == 'etag_value' + + +def test_get_feature_from_dict(): + test_get_feature(request_type=dict) + + +def test_get_feature_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + client.get_feature() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetFeatureRequest() + +@pytest.mark.asyncio +async def test_get_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.GetFeatureRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(feature.Feature( + name='name_value', + description='description_value', + value_type=feature.Feature.ValueType.BOOL, + etag='etag_value', + )) + + response = await client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.GetFeatureRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, feature.Feature) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.value_type == feature.Feature.ValueType.BOOL + + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_get_feature_async_from_dict(): + await test_get_feature_async(request_type=dict) + + +def test_get_feature_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.GetFeatureRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + call.return_value = feature.Feature() + + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_feature_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.GetFeatureRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(feature.Feature()) + + await client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_feature_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = feature.Feature() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_feature( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_feature_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_feature( + featurestore_service.GetFeatureRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_feature_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = feature.Feature() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(feature.Feature()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_feature( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_feature_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_feature( + featurestore_service.GetFeatureRequest(), + name='name_value', + ) + + +def test_list_features(transport: str = 'grpc', request_type=featurestore_service.ListFeaturesRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListFeaturesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListFeaturesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListFeaturesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_features_from_dict(): + test_list_features(request_type=dict) + + +def test_list_features_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + client.list_features() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListFeaturesRequest() + +@pytest.mark.asyncio +async def test_list_features_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ListFeaturesRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ListFeaturesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListFeaturesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_features_async_from_dict(): + await test_list_features_async(request_type=dict) + + +def test_list_features_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ListFeaturesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + call.return_value = featurestore_service.ListFeaturesResponse() + + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_features_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ListFeaturesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturesResponse()) + + await client.list_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_features_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListFeaturesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_features( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_features_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_features( + featurestore_service.ListFeaturesRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_features_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.ListFeaturesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturesResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_features( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_features_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_features( + featurestore_service.ListFeaturesRequest(), + parent='parent_value', + ) + + +def test_list_features_pager(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_features(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, feature.Feature) + for i in results) + +def test_list_features_pages(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + pages = list(client.list_features(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_features_async_pager(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_features(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, feature.Feature) + for i in responses) + +@pytest.mark.asyncio +async def test_list_features_async_pages(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_features), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.ListFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.ListFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_features(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_feature(transport: str = 'grpc', request_type=featurestore_service.UpdateFeatureRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_feature.Feature( + name='name_value', + + description='description_value', + + value_type=gca_feature.Feature.ValueType.BOOL, + + etag='etag_value', + + ) + + response = client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateFeatureRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_feature.Feature) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.value_type == gca_feature.Feature.ValueType.BOOL + + assert response.etag == 'etag_value' + + +def test_update_feature_from_dict(): + test_update_feature(request_type=dict) + + +def test_update_feature_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + client.update_feature() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateFeatureRequest() + +@pytest.mark.asyncio +async def test_update_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.UpdateFeatureRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_feature.Feature( + name='name_value', + description='description_value', + value_type=gca_feature.Feature.ValueType.BOOL, + etag='etag_value', + )) + + response = await client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.UpdateFeatureRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_feature.Feature) + + assert response.name == 'name_value' + + assert response.description == 'description_value' + + assert response.value_type == gca_feature.Feature.ValueType.BOOL + + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_update_feature_async_from_dict(): + await test_update_feature_async(request_type=dict) + + +def test_update_feature_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.UpdateFeatureRequest() + request.feature.name = 'feature.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + call.return_value = gca_feature.Feature() + + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'feature.name=feature.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_feature_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.UpdateFeatureRequest() + request.feature.name = 'feature.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_feature.Feature()) + + await client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'feature.name=feature.name/value', + ) in kw['metadata'] + + +def test_update_feature_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_feature.Feature() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_feature( + feature=gca_feature.Feature(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].feature == gca_feature.Feature(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_feature_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_feature( + featurestore_service.UpdateFeatureRequest(), + feature=gca_feature.Feature(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_feature_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_feature.Feature() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_feature.Feature()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_feature( + feature=gca_feature.Feature(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].feature == gca_feature.Feature(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_feature_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_feature( + featurestore_service.UpdateFeatureRequest(), + feature=gca_feature.Feature(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_feature(transport: str = 'grpc', request_type=featurestore_service.DeleteFeatureRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteFeatureRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_feature_from_dict(): + test_delete_feature(request_type=dict) + + +def test_delete_feature_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + client.delete_feature() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteFeatureRequest() + +@pytest.mark.asyncio +async def test_delete_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.DeleteFeatureRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.DeleteFeatureRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_feature_async_from_dict(): + await test_delete_feature_async(request_type=dict) + + +def test_delete_feature_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.DeleteFeatureRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_feature_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.DeleteFeatureRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_feature_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_feature( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_feature_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_feature( + featurestore_service.DeleteFeatureRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_feature_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_feature), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_feature( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_feature_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_feature( + featurestore_service.DeleteFeatureRequest(), + name='name_value', + ) + + +def test_import_feature_values(transport: str = 'grpc', request_type=featurestore_service.ImportFeatureValuesRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ImportFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_import_feature_values_from_dict(): + test_import_feature_values(request_type=dict) + + +def test_import_feature_values_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + client.import_feature_values() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ImportFeatureValuesRequest() + +@pytest.mark.asyncio +async def test_import_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ImportFeatureValuesRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ImportFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_import_feature_values_async_from_dict(): + await test_import_feature_values_async(request_type=dict) + + +def test_import_feature_values_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ImportFeatureValuesRequest() + request.entity_type = 'entity_type/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type=entity_type/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_import_feature_values_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ImportFeatureValuesRequest() + request.entity_type = 'entity_type/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'entity_type=entity_type/value', + ) in kw['metadata'] + + +def test_import_feature_values_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.import_feature_values( + entity_type='entity_type_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == 'entity_type_value' + + +def test_import_feature_values_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.import_feature_values( + featurestore_service.ImportFeatureValuesRequest(), + entity_type='entity_type_value', + ) + + +@pytest.mark.asyncio +async def test_import_feature_values_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.import_feature_values( + entity_type='entity_type_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == 'entity_type_value' + + +@pytest.mark.asyncio +async def test_import_feature_values_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.import_feature_values( + featurestore_service.ImportFeatureValuesRequest(), + entity_type='entity_type_value', + ) + + +def test_batch_read_feature_values(transport: str = 'grpc', request_type=featurestore_service.BatchReadFeatureValuesRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.BatchReadFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_batch_read_feature_values_from_dict(): + test_batch_read_feature_values(request_type=dict) + + +def test_batch_read_feature_values_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + client.batch_read_feature_values() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.BatchReadFeatureValuesRequest() + +@pytest.mark.asyncio +async def test_batch_read_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.BatchReadFeatureValuesRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.BatchReadFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_batch_read_feature_values_async_from_dict(): + await test_batch_read_feature_values_async(request_type=dict) + + +def test_batch_read_feature_values_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.BatchReadFeatureValuesRequest() + request.featurestore = 'featurestore/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'featurestore=featurestore/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_batch_read_feature_values_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.BatchReadFeatureValuesRequest() + request.featurestore = 'featurestore/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'featurestore=featurestore/value', + ) in kw['metadata'] + + +def test_batch_read_feature_values_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_read_feature_values( + featurestore='featurestore_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].featurestore == 'featurestore_value' + + +def test_batch_read_feature_values_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_read_feature_values( + featurestore_service.BatchReadFeatureValuesRequest(), + featurestore='featurestore_value', + ) + + +@pytest.mark.asyncio +async def test_batch_read_feature_values_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_read_feature_values), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_read_feature_values( + featurestore='featurestore_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].featurestore == 'featurestore_value' + + +@pytest.mark.asyncio +async def test_batch_read_feature_values_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_read_feature_values( + featurestore_service.BatchReadFeatureValuesRequest(), + featurestore='featurestore_value', + ) + + +def test_search_features(transport: str = 'grpc', request_type=featurestore_service.SearchFeaturesRequest): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.SearchFeaturesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.search_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.SearchFeaturesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.SearchFeaturesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_search_features_from_dict(): + test_search_features(request_type=dict) + + +def test_search_features_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + client.search_features() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.SearchFeaturesRequest() + +@pytest.mark.asyncio +async def test_search_features_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.SearchFeaturesRequest): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.SearchFeaturesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.search_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.SearchFeaturesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.SearchFeaturesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_search_features_async_from_dict(): + await test_search_features_async(request_type=dict) + + +def test_search_features_field_headers(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.SearchFeaturesRequest() + request.location = 'location/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + call.return_value = featurestore_service.SearchFeaturesResponse() + + client.search_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'location=location/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_search_features_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.SearchFeaturesRequest() + request.location = 'location/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.SearchFeaturesResponse()) + + await client.search_features(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'location=location/value', + ) in kw['metadata'] + + +def test_search_features_flattened(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.SearchFeaturesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.search_features( + location='location_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].location == 'location_value' + + +def test_search_features_flattened_error(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.search_features( + featurestore_service.SearchFeaturesRequest(), + location='location_value', + ) + + +@pytest.mark.asyncio +async def test_search_features_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = featurestore_service.SearchFeaturesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.SearchFeaturesResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.search_features( + location='location_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].location == 'location_value' + + +@pytest.mark.asyncio +async def test_search_features_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.search_features( + featurestore_service.SearchFeaturesRequest(), + location='location_value', + ) + + +def test_search_features_pager(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.SearchFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('location', ''), + )), + ) + pager = client.search_features(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, feature.Feature) + for i in results) + +def test_search_features_pages(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.SearchFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + pages = list(client.search_features(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_search_features_async_pager(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.SearchFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + async_pager = await client.search_features(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, feature.Feature) + for i in responses) + +@pytest.mark.asyncio +async def test_search_features_async_pages(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_features), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + feature.Feature(), + ], + next_page_token='abc', + ), + featurestore_service.SearchFeaturesResponse( + features=[], + next_page_token='def', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + ], + next_page_token='ghi', + ), + featurestore_service.SearchFeaturesResponse( + features=[ + feature.Feature(), + feature.Feature(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.search_features(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FeaturestoreServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FeaturestoreServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = FeaturestoreServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.FeaturestoreServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.FeaturestoreServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.FeaturestoreServiceGrpcTransport, + transports.FeaturestoreServiceGrpcAsyncIOTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.FeaturestoreServiceGrpcTransport, + ) + + +def test_featurestore_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.FeaturestoreServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_featurestore_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.FeaturestoreServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_featurestore', + 'get_featurestore', + 'list_featurestores', + 'update_featurestore', + 'delete_featurestore', + 'create_entity_type', + 'get_entity_type', + 'list_entity_types', + 'update_entity_type', + 'delete_entity_type', + 'create_feature', + 'batch_create_features', + 'get_feature', + 'list_features', + 'update_feature', + 'delete_feature', + 'import_feature_values', + 'batch_read_feature_values', + 'search_features', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_featurestore_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.FeaturestoreServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_featurestore_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.FeaturestoreServiceTransport() + adc.assert_called_once() + + +def test_featurestore_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + FeaturestoreServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_featurestore_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.FeaturestoreServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize("transport_class", [transports.FeaturestoreServiceGrpcTransport, transports.FeaturestoreServiceGrpcAsyncIOTransport]) +def test_featurestore_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + + +def test_featurestore_service_host_no_port(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_featurestore_service_host_with_port(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_featurestore_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.FeaturestoreServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_featurestore_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.FeaturestoreServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.FeaturestoreServiceGrpcTransport, transports.FeaturestoreServiceGrpcAsyncIOTransport]) +def test_featurestore_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.FeaturestoreServiceGrpcTransport, transports.FeaturestoreServiceGrpcAsyncIOTransport]) +def test_featurestore_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_featurestore_service_grpc_lro_client(): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_featurestore_service_grpc_lro_async_client(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_entity_type_path(): + project = "squid" + location = "clam" + featurestore = "whelk" + entity_type = "octopus" + + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) + actual = FeaturestoreServiceClient.entity_type_path(project, location, featurestore, entity_type) + assert expected == actual + + +def test_parse_entity_type_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "featurestore": "cuttlefish", + "entity_type": "mussel", + + } + path = FeaturestoreServiceClient.entity_type_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_entity_type_path(path) + assert expected == actual + +def test_feature_path(): + project = "winkle" + location = "nautilus" + featurestore = "scallop" + entity_type = "abalone" + feature = "squid" + + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, feature=feature, ) + actual = FeaturestoreServiceClient.feature_path(project, location, featurestore, entity_type, feature) + assert expected == actual + + +def test_parse_feature_path(): + expected = { + "project": "clam", + "location": "whelk", + "featurestore": "octopus", + "entity_type": "oyster", + "feature": "nudibranch", + + } + path = FeaturestoreServiceClient.feature_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_feature_path(path) + assert expected == actual + +def test_featurestore_path(): + project = "cuttlefish" + location = "mussel" + featurestore = "winkle" + + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}".format(project=project, location=location, featurestore=featurestore, ) + actual = FeaturestoreServiceClient.featurestore_path(project, location, featurestore) + assert expected == actual + + +def test_parse_featurestore_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "featurestore": "abalone", + + } + path = FeaturestoreServiceClient.featurestore_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_featurestore_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = FeaturestoreServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + + } + path = FeaturestoreServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder, ) + actual = FeaturestoreServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + + } + path = FeaturestoreServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = FeaturestoreServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + + } + path = FeaturestoreServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project, ) + actual = FeaturestoreServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + + } + path = FeaturestoreServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = FeaturestoreServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + + } + path = FeaturestoreServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = FeaturestoreServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.FeaturestoreServiceTransport, '_prep_wrapped_messages') as prep: + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.FeaturestoreServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = FeaturestoreServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) From 2fa7e504e74cd419e2b696f1202476c802eba9e4 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Mon, 12 Apr 2021 17:47:10 -0400 Subject: [PATCH 06/36] test --- google/cloud/aiplatform_v1beta1/__init__.py | 2 + .../services/metadata_service/async_client.py | 87 +++++++ .../services/metadata_service/client.py | 88 +++++++ .../metadata_service/transports/base.py | 14 ++ .../metadata_service/transports/grpc.py | 29 +++ .../transports/grpc_asyncio.py | 29 +++ .../services/migration_service/client.py | 12 +- .../aiplatform_v1beta1/types/__init__.py | 2 + .../types/metadata_service.py | 27 ++ .../test_metadata_service.py | 236 ++++++++++++++++++ .../test_migration_service.py | 36 +-- 11 files changed, 538 insertions(+), 24 deletions(-) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 2936282360..0ec5663b24 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -233,6 +233,7 @@ from .types.metadata_service import ListMetadataSchemasResponse from .types.metadata_service import ListMetadataStoresRequest from .types.metadata_service import ListMetadataStoresResponse +from .types.metadata_service import QueryArtifactLineageSubgraphRequest from .types.metadata_service import QueryContextLineageSubgraphRequest from .types.metadata_service import QueryExecutionInputsAndOutputsRequest from .types.metadata_service import UpdateArtifactRequest @@ -607,6 +608,7 @@ 'PredictSchemata', 'PredictionServiceClient', 'PythonPackageSpec', + 'QueryArtifactLineageSubgraphRequest', 'QueryContextLineageSubgraphRequest', 'QueryExecutionInputsAndOutputsRequest', 'ReadFeatureValuesRequest', diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index 5b5275ba33..f42523d3f2 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -2466,6 +2466,93 @@ async def list_metadata_schemas(self, # Done; return the response. return response + async def query_artifact_lineage_subgraph(self, + request: metadata_service.QueryArtifactLineageSubgraphRequest = None, + *, + artifact: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: + r"""Retrieves lineage of an Artifact represented through + Artifacts and Executions connected by Event edges and + returned as a LineageSubgraph. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.QueryArtifactLineageSubgraphRequest`): + The request object. Request message for + ``MetadataService.QueryArtifactLineageSubgraph``. + artifact (:class:`str`): + Required. The resource name of the Artifact whose + Lineage needs to be retrieved as a LineageSubgraph. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + The request may error with FAILED_PRECONDITION if the + number of Artifacts, the number of Executions, or the + number of Events that would be returned for the Context + exceeds 1000. + + This corresponds to the ``artifact`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.LineageSubgraph: + A subgraph of the overall lineage + graph. Event edges connect Artifact and + Execution nodes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([artifact]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = metadata_service.QueryArtifactLineageSubgraphRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if artifact is not None: + request.artifact = artifact + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.query_artifact_lineage_subgraph, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('artifact', request.artifact), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index 06ca29cf5a..6983d6e5fd 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -2696,6 +2696,94 @@ def list_metadata_schemas(self, # Done; return the response. return response + def query_artifact_lineage_subgraph(self, + request: metadata_service.QueryArtifactLineageSubgraphRequest = None, + *, + artifact: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: + r"""Retrieves lineage of an Artifact represented through + Artifacts and Executions connected by Event edges and + returned as a LineageSubgraph. + + Args: + request (google.cloud.aiplatform_v1beta1.types.QueryArtifactLineageSubgraphRequest): + The request object. Request message for + ``MetadataService.QueryArtifactLineageSubgraph``. + artifact (str): + Required. The resource name of the Artifact whose + Lineage needs to be retrieved as a LineageSubgraph. + Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + The request may error with FAILED_PRECONDITION if the + number of Artifacts, the number of Executions, or the + number of Events that would be returned for the Context + exceeds 1000. + + This corresponds to the ``artifact`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.LineageSubgraph: + A subgraph of the overall lineage + graph. Event edges connect Artifact and + Execution nodes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([artifact]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.QueryArtifactLineageSubgraphRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.QueryArtifactLineageSubgraphRequest): + request = metadata_service.QueryArtifactLineageSubgraphRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if artifact is not None: + request.artifact = artifact + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.query_artifact_lineage_subgraph] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('artifact', request.artifact), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py index 76ef934c98..f4acfb6800 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py @@ -241,6 +241,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.query_artifact_lineage_subgraph: gapic_v1.method.wrap_method( + self.query_artifact_lineage_subgraph, + default_timeout=None, + client_info=client_info, + ), } @@ -474,6 +479,15 @@ def list_metadata_schemas(self) -> typing.Callable[ ]]: raise NotImplementedError() + @property + def query_artifact_lineage_subgraph(self) -> typing.Callable[ + [metadata_service.QueryArtifactLineageSubgraphRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph] + ]]: + raise NotImplementedError() + __all__ = ( 'MetadataServiceTransport', diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py index 7cc6484f91..12ca2e4cc2 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -911,6 +911,35 @@ def list_metadata_schemas(self) -> Callable[ ) return self._stubs['list_metadata_schemas'] + @property + def query_artifact_lineage_subgraph(self) -> Callable[ + [metadata_service.QueryArtifactLineageSubgraphRequest], + lineage_subgraph.LineageSubgraph]: + r"""Return a callable for the query artifact lineage + subgraph method over gRPC. + + Retrieves lineage of an Artifact represented through + Artifacts and Executions connected by Event edges and + returned as a LineageSubgraph. + + Returns: + Callable[[~.QueryArtifactLineageSubgraphRequest], + ~.LineageSubgraph]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'query_artifact_lineage_subgraph' not in self._stubs: + self._stubs['query_artifact_lineage_subgraph'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/QueryArtifactLineageSubgraph', + request_serializer=metadata_service.QueryArtifactLineageSubgraphRequest.serialize, + response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, + ) + return self._stubs['query_artifact_lineage_subgraph'] + __all__ = ( 'MetadataServiceGrpcTransport', diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py index bedea761c0..083f379def 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -916,6 +916,35 @@ def list_metadata_schemas(self) -> Callable[ ) return self._stubs['list_metadata_schemas'] + @property + def query_artifact_lineage_subgraph(self) -> Callable[ + [metadata_service.QueryArtifactLineageSubgraphRequest], + Awaitable[lineage_subgraph.LineageSubgraph]]: + r"""Return a callable for the query artifact lineage + subgraph method over gRPC. + + Retrieves lineage of an Artifact represented through + Artifacts and Executions connected by Event edges and + returned as a LineageSubgraph. + + Returns: + Callable[[~.QueryArtifactLineageSubgraphRequest], + Awaitable[~.LineageSubgraph]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'query_artifact_lineage_subgraph' not in self._stubs: + self._stubs['query_artifact_lineage_subgraph'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MetadataService/QueryArtifactLineageSubgraph', + request_serializer=metadata_service.QueryArtifactLineageSubgraphRequest.serialize, + response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, + ) + return self._stubs['query_artifact_lineage_subgraph'] + __all__ = ( 'MetadataServiceGrpcAsyncIOTransport', diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index a636962692..1c08ffef30 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -173,14 +173,14 @@ def parse_annotated_dataset_path(path: str) -> Dict[str,str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -195,14 +195,14 @@ def parse_dataset_path(path: str) -> Dict[str,str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index aff56b122d..17a2d7a221 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -290,6 +290,7 @@ ListMetadataSchemasResponse, ListMetadataStoresRequest, ListMetadataStoresResponse, + QueryArtifactLineageSubgraphRequest, QueryContextLineageSubgraphRequest, QueryExecutionInputsAndOutputsRequest, UpdateArtifactRequest, @@ -642,6 +643,7 @@ 'ListMetadataSchemasResponse', 'ListMetadataStoresRequest', 'ListMetadataStoresResponse', + 'QueryArtifactLineageSubgraphRequest', 'QueryContextLineageSubgraphRequest', 'QueryExecutionInputsAndOutputsRequest', 'UpdateArtifactRequest', diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_service.py b/google/cloud/aiplatform_v1beta1/types/metadata_service.py index 3777316237..96ceb992ad 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_service.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_service.py @@ -66,6 +66,7 @@ 'GetMetadataSchemaRequest', 'ListMetadataSchemasRequest', 'ListMetadataSchemasResponse', + 'QueryArtifactLineageSubgraphRequest', }, ) @@ -897,4 +898,30 @@ def raw_page(self): next_page_token = proto.Field(proto.STRING, number=2) +class QueryArtifactLineageSubgraphRequest(proto.Message): + r"""Request message for + ``MetadataService.QueryArtifactLineageSubgraph``. + + Attributes: + artifact (str): + Required. The resource name of the Artifact whose Lineage + needs to be retrieved as a LineageSubgraph. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + The request may error with FAILED_PRECONDITION if the number + of Artifacts, the number of Executions, or the number of + Events that would be returned for the Context exceeds 1000. + max_hops (int): + Specifies the size of the lineage graph in terms of number + of hops from the specified artifact. Negative Value: + INVALID_ARGUMENT error is returned 0: Only input artifact is + returned. No value: Transitive closure is performed to + return the complete graph. + """ + + artifact = proto.Field(proto.STRING, number=1) + + max_hops = proto.Field(proto.INT32, number=2) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index 0a71403d33..35b2de66b5 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -7893,6 +7893,241 @@ async def test_list_metadata_schemas_async_pages(): assert page_.raw_page.next_page_token == token +def test_query_artifact_lineage_subgraph(transport: str = 'grpc', request_type=metadata_service.QueryArtifactLineageSubgraphRequest): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph( + ) + + response = client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryArtifactLineageSubgraphRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, lineage_subgraph.LineageSubgraph) + + +def test_query_artifact_lineage_subgraph_from_dict(): + test_query_artifact_lineage_subgraph(request_type=dict) + + +def test_query_artifact_lineage_subgraph_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + client.query_artifact_lineage_subgraph() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryArtifactLineageSubgraphRequest() + +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryArtifactLineageSubgraphRequest): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( + )) + + response = await client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == metadata_service.QueryArtifactLineageSubgraphRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, lineage_subgraph.LineageSubgraph) + + +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_async_from_dict(): + await test_query_artifact_lineage_subgraph_async(request_type=dict) + + +def test_query_artifact_lineage_subgraph_field_headers(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.QueryArtifactLineageSubgraphRequest() + request.artifact = 'artifact/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + call.return_value = lineage_subgraph.LineageSubgraph() + + client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'artifact=artifact/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.QueryArtifactLineageSubgraphRequest() + request.artifact = 'artifact/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + + await client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'artifact=artifact/value', + ) in kw['metadata'] + + +def test_query_artifact_lineage_subgraph_flattened(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.query_artifact_lineage_subgraph( + artifact='artifact_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].artifact == 'artifact_value' + + +def test_query_artifact_lineage_subgraph_flattened_error(): + client = MetadataServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.query_artifact_lineage_subgraph( + metadata_service.QueryArtifactLineageSubgraphRequest(), + artifact='artifact_value', + ) + + +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_artifact_lineage_subgraph), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = lineage_subgraph.LineageSubgraph() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.query_artifact_lineage_subgraph( + artifact='artifact_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].artifact == 'artifact_value' + + +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.query_artifact_lineage_subgraph( + metadata_service.QueryArtifactLineageSubgraphRequest(), + artifact='artifact_value', + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.MetadataServiceGrpcTransport( @@ -8017,6 +8252,7 @@ def test_metadata_service_base_transport(): 'create_metadata_schema', 'get_metadata_schema', 'list_metadata_schemas', + 'query_artifact_lineage_subgraph', ) for method in methods: with pytest.raises(NotImplementedError): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 85cf790381..134e0632c7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -1420,19 +1420,17 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - location = "mussel" - dataset = "winkle" + dataset = "mussel" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1442,9 +1440,9 @@ def test_parse_dataset_path(): assert expected == actual def test_dataset_path(): - project = "squid" - location = "clam" - dataset = "whelk" + project = "scallop" + location = "abalone" + dataset = "squid" expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, location, dataset) @@ -1453,9 +1451,9 @@ def test_dataset_path(): def test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", + "project": "clam", + "location": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1465,17 +1463,19 @@ def test_parse_dataset_path(): assert expected == actual def test_dataset_path(): - project = "cuttlefish" - dataset = "mussel" + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", + "project": "mussel", + "location": "winkle", "dataset": "nautilus", } From 5cdb9afa314de5c1d81cce466322084d17879d28 Mon Sep 17 00:00:00 2001 From: "Ignacio (Nacho) Cano" Date: Tue, 13 Apr 2021 08:23:38 -0700 Subject: [PATCH 07/36] feat: adding MetadataService API and new classes for Context, Artifact and Execution (#303) --- google/cloud/aiplatform/__init__.py | 17 + google/cloud/aiplatform/initializer.py | 21 +- google/cloud/aiplatform/metadata/artifact.py | 59 +++ google/cloud/aiplatform/metadata/constants.py | 28 ++ google/cloud/aiplatform/metadata/context.py | 59 +++ google/cloud/aiplatform/metadata/execution.py | 59 +++ google/cloud/aiplatform/metadata/metadata.py | 99 +++++ .../aiplatform/metadata/metadata_store.py | 117 +++++- google/cloud/aiplatform/metadata/resource.py | 369 ++++++++++++++++++ google/cloud/aiplatform/utils.py | 17 +- tests/unit/aiplatform/test_initializer.py | 21 +- .../aiplatform/test_metadata_resources.py | 298 ++++++++++++++ ...est_metadata.py => test_metadata_store.py} | 4 +- tests/unit/aiplatform/test_utils.py | 38 ++ 14 files changed, 1179 insertions(+), 27 deletions(-) create mode 100644 google/cloud/aiplatform/metadata/artifact.py create mode 100644 google/cloud/aiplatform/metadata/constants.py create mode 100644 google/cloud/aiplatform/metadata/context.py create mode 100644 google/cloud/aiplatform/metadata/execution.py create mode 100644 google/cloud/aiplatform/metadata/metadata.py create mode 100644 google/cloud/aiplatform/metadata/resource.py create mode 100644 tests/unit/aiplatform/test_metadata_resources.py rename tests/unit/aiplatform/{test_metadata.py => test_metadata_store.py} (98%) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 9c94090548..73a6342b99 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -38,6 +38,7 @@ AutoMLTextTrainingJob, AutoMLVideoTrainingJob, ) +from google.cloud.aiplatform.metadata import metadata """ Usage: @@ -47,10 +48,26 @@ """ init = initializer.global_config.init +log_param = metadata.metadata_service.log_param +log_params = metadata.metadata_service.log_params +log_metric = metadata.metadata_service.log_metric +log_metrics = metadata.metadata_service.log_metrics +set_experiment = metadata.metadata_service.set_experiment +get_experiment = metadata.metadata_service.get_experiment +set_run = metadata.metadata_service.set_run + + __all__ = ( "explain", "gapic", "init", + "log_param", + "log_params", + "log_metric", + "log_metrics", + "get_experiment", + "set_experiment", + "set_run", "AutoMLImageTrainingJob", "AutoMLTabularTrainingJob", "AutoMLTextTrainingJob", diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index b84a006d02..c6dc61dd2e 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -31,6 +31,7 @@ from google.cloud.aiplatform import compat from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata import metadata from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec_compat, @@ -44,7 +45,6 @@ class _Config: def __init__(self): self._project = None - self._experiment = None self._location = None self._staging_bucket = None self._credentials = None @@ -56,6 +56,7 @@ def init( project: Optional[str] = None, location: Optional[str] = None, experiment: Optional[str] = None, + run: Optional[str] = None, staging_bucket: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, encryption_spec_key_name: Optional[str] = None, @@ -66,7 +67,8 @@ def init( project (str): The default project to use when making API calls. location (str): The default location to use when making API calls. If not set defaults to us-central-1 - experiment (str): The experiment to assign + experiment (str): The experiment name + run (str): The run name staging_bucket (str): The default staging bucket to use to stage artifacts when making API calls. In the form gs://... credentials (google.auth.crendentials.Credentials): The default custom @@ -88,8 +90,14 @@ def init( utils.validate_region(location) self._location = location if experiment: - logging.warning("Experiments currently not supported.") - self._experiment = experiment + metadata.metadata_service.set_experiment(experiment) + if run: + if not experiment: + raise ValueError( + "No experiment set. Provide an experiment for this run, e.g., aiplatform.init(" + "experiment='my-experiment')." + ) + metadata.metadata_service.set_run(run) if staging_bucket: self._staging_bucket = staging_bucket if credentials: @@ -153,11 +161,6 @@ def location(self) -> str: """Default location.""" return self._location or constants.DEFAULT_REGION - @property - def experiment(self) -> Optional[str]: - """Default experiment, if provided.""" - return self._experiment - @property def staging_bucket(self) -> Optional[str]: """Default staging bucket, if provided.""" diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py new file mode 100644 index 0000000000..9f8505c71b --- /dev/null +++ b/google/cloud/aiplatform/metadata/artifact.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto +from typing import Optional, Dict + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.resource import _Resource + +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact + + +class _Artifact(_Resource): + """Metadata Artifact resource for AI Platform""" + + _resource_noun = "artifacts" + _getter_method = "get_artifact" + + @classmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + gapic_artifact = gca_artifact.Artifact( + schema_title=schema_title, + schema_version=schema_version, + display_name=display_name, + description=description, + metadata=metadata if metadata else {}, + ) + return client.create_artifact( + parent=parent, artifact=gapic_artifact, artifact_id=resource_id, + ) + + @classmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + return client.update_artifact(artifact=resource) diff --git a/google/cloud/aiplatform/metadata/constants.py b/google/cloud/aiplatform/metadata/constants.py new file mode 100644 index 0000000000..cc905e8b9e --- /dev/null +++ b/google/cloud/aiplatform/metadata/constants.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SYSTEM_RUN = "system.Run" +SYSTEM_EXPERIMENT = "system.Experiment" +SYSTEM_METRICS = "system.Metrics" + +_DEFAULT_SCHEMA_VERSION = "0.0.1" + +SCHEMA_VERSIONS = { + SYSTEM_RUN: _DEFAULT_SCHEMA_VERSION, + SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION, + SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION, +} diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py new file mode 100644 index 0000000000..11cb297365 --- /dev/null +++ b/google/cloud/aiplatform/metadata/context.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto +from typing import Optional, Dict + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.resource import _Resource + +from google.cloud.aiplatform_v1beta1.types import context as gca_context + + +class _Context(_Resource): + """Metadata Context resource for AI Platform""" + + _resource_noun = "contexts" + _getter_method = "get_context" + + @classmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + gapic_context = gca_context.Context( + schema_title=schema_title, + schema_version=schema_version, + display_name=display_name, + description=description, + metadata=metadata if metadata else {}, + ) + return client.create_context( + parent=parent, context=gapic_context, context_id=resource_id, + ) + + @classmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + return client.update_context(context=resource) diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py new file mode 100644 index 0000000000..060d562660 --- /dev/null +++ b/google/cloud/aiplatform/metadata/execution.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto +from typing import Optional, Dict + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.resource import _Resource + +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution + + +class _Execution(_Resource): + """Metadata Execution resource for AI Platform""" + + _resource_noun = "executions" + _getter_method = "get_execution" + + @classmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + gapic_execution = gca_execution.Execution( + schema_title=schema_title, + schema_version=schema_version, + display_name=display_name, + description=description, + metadata=metadata if metadata else {}, + ) + return client.create_execution( + parent=parent, execution=gapic_execution, execution_id=resource_id, + ) + + @classmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + return client.update_execution(execution=resource) diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py new file mode 100644 index 0000000000..069b7009a7 --- /dev/null +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, Union + +from google.cloud.aiplatform.metadata.metadata_store import _MetadataStore +from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata.context import _Context +from google.cloud.aiplatform.metadata.execution import _Execution +from google.cloud.aiplatform.metadata.artifact import _Artifact + + +class _MetadataService: + """Contains the exposed APIs to interact with the Managed Metadata Service.""" + + def __init__(self): + self._experiment = None + self._run = None + + def set_experiment(self, experiment: str): + _MetadataStore.get_or_create() + context = _Context.get_or_create( + resource_id=experiment, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + ) + self._experiment = context.name + + def set_run(self, run: str): + if not self._experiment: + raise ValueError( + "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') or " + "aiplatform.set_experiment(experiment='my-experiment') before trying to set_run. " + ) + execution = _Execution.get_or_create( + resource_id=run, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + self._run = execution.name + + def log_param(self, name: str, value: Union[float, int, str]): + return self.log_params({name: value}) + + def log_params(self, params: Dict[str, Union[float, int, str]]): + self._validate_experiment_and_run(method_name="log_params") + execution = _Execution.get_or_create( + resource_id=self._run, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + execution.update(metadata=params) + self._run = execution.name + + def log_metric(self, name: str, value: Union[str, float, int]): + return self.log_metrics({name: value}) + + def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): + self._validate_experiment_and_run(method_name="log_metrics") + # Only one metrics artifact for the (experiment, run) tuple. + artifact_id = f"{self._experiment}-{self._run}" + artifact = _Artifact.get_or_create( + resource_id=artifact_id, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + artifact.update(metadata=metrics) + + def get_experiment(self, experiment: str): + raise NotImplementedError("get_experiment not implemented") + + def _validate_experiment_and_run(self, method_name: str): + if not self._experiment: + raise ValueError( + f"No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') " + f"or aiplatform.set_experiment(experiment='my-experiment') before trying to {method_name}. " + ) + if not self._run: + raise ValueError( + f"No run set. Make sure to call aiplatform.init(experiment='my-experiment', " + f"run='my-run') or aiplatform.set_run('my-run') before trying to {method_name}. " + ) + + +metadata_service = _MetadataService() diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py index 3c187813a1..3c3664b1d5 100644 --- a/google/cloud/aiplatform/metadata/metadata_store.py +++ b/google/cloud/aiplatform/metadata/metadata_store.py @@ -68,7 +68,67 @@ def __init__( self._gca_resource = self._get_gca_resource(resource_name=metadata_store_name) @classmethod - def create( + def get_or_create( + cls, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + ) -> "_MetadataStore": + """"Retrieves or Creates (if it does not exist) a Metadata Store. + + Args: + metadata_store_id (str): + The portion of the resource name with the format: + projects/123/locations/us-central1/metadataStores/ + If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore. + project (str): + Project used to retrieve or create the metadata store. Overrides project set in + aiplatform.init. + location (str): + Location used to retrieve or create the metadata store. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to retrieve or create the metadata store. Overrides + credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the metadata store. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + + + Returns: + metadata_store (_MetadataStore): + Instantiated representation of the managed metadata store resource. + + """ + + store = cls._get( + metadata_store_name=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + if not store: + store = cls._create( + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + encryption_spec_key_name=encryption_spec_key_name, + ) + return store + + @classmethod + def _create( cls, metadata_store_id: str = "default", project: Optional[str] = None, @@ -80,18 +140,18 @@ def create( Args: metadata_store_id (str): - The {metadatastore} portion of the resource name with + The portion of the resource name with the format: - projects/{project}/locations/{location}/metadataStores/{metadatastore} + projects/123/locations/us-central1/metadataStores/ If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore. project (str): - Project to upload this model to. Overrides project set in + Project used to create the metadata store. Overrides project set in aiplatform.init. location (str): - Location to upload this model to. Overrides location set in + Location used to create the metadata store. Overrides location set in aiplatform.init. credentials (auth_credentials.Credentials): - Custom credentials to use to upload this model. Overrides + Custom credentials used to create the metadata store. Overrides credentials set in aiplatform.init. encryption_spec_key_name (Optional[str]): Optional. The Cloud KMS resource identifier of the customer @@ -107,7 +167,7 @@ def create( Returns: - metadata_store (MetadataStore): + metadata_store (_MetadataStore): Instantiated representation of the managed metadata store resource. """ @@ -128,7 +188,7 @@ def create( metadata_store_id=metadata_store_id, ).result() except exceptions.AlreadyExists: - logging.info("MetadataStore %s already exists" % metadata_store_id) + logging.info(f"MetadataStore '{metadata_store_id}' already exists") return cls( metadata_store_name=metadata_store_id, @@ -136,3 +196,44 @@ def create( location=location, credentials=credentials, ) + + @classmethod + def _get( + cls, + metadata_store_name: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Optional[_MetadataStore]": + """Returns a MetadataStore resource. + + Args: + metadata_store_name (str): + Optional. A fully-qualified MetadataStore resource name or metadataStore ID. + Example: "projects/123/locations/us-central1/metadataStores/my-store" or + "my-store" when project and location are initialized or passed. + If not set, metadata_store_name will be set to "default". + project (str): + Optional project to retrieve the metadata store from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve the metadata store from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to retrieve this metadata store. Overrides + credentials set in aiplatform.init. + + Returns: + metadata_store (Optional[_MetadataStore]): + An optional instantiated representation of the managed Metadata Store resource. + """ + + try: + return cls( + metadata_store_name=metadata_store_name, + project=project, + location=location, + credentials=credentials, + ) + except exceptions.NotFound: + logging.info(f"MetadataStore {metadata_store_name} not found.") diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py new file mode 100644 index 0000000000..ab0a3a9aa4 --- /dev/null +++ b/google/cloud/aiplatform/metadata/resource.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +import re +import proto +import logging +from typing import Optional, Dict +from copy import deepcopy + +from google.api_core import exceptions +from google.cloud.aiplatform import utils +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base, initializer + + +class _Resource(base.AiPlatformResourceNounWithFutureManager, abc.ABC): + """Metadata Resource for AI Platform""" + + client_class = utils.MetadataClientWithOverride + _is_client_prediction_client = False + _delete_method = None + + def __init__( + self, + resource_name: str, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing Metadata resource given a resource name or ID. + + Args: + resource_name (str): + A fully-qualified resource name or ID + Example: "projects/123/locations/us-central1/metadataStores/default//my-resource". + or "my-resource" when project and location are initialized or passed. + metadata_store_id (str): + MetadataStore to retrieve resource from. If not set, metadata_store_id is set to "default". + If resource_name is a fully-qualified resource, its metadata_store_id overrides this one. + project (str): + Optional project to retrieve the resource from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve the resource from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + """ + + super().__init__( + project=project, location=location, credentials=credentials, + ) + + # If we receive a full resource name, we extract the metadata_store_id and use that + if "/" in resource_name: + metadata_store_id = _Resource._extract_metadata_store_id( + resource_name, self._resource_noun + ) + + full_resource_name = utils.full_resource_name( + resource_name=resource_name, + resource_noun=f"metadataStores/{metadata_store_id}/{self._resource_noun}", + project=self.project, + location=self.location, + ) + + self._gca_resource = getattr(self.api_client, self._getter_method)( + name=full_resource_name + ) + + @classmethod + def get_or_create( + cls, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "_Resource": + """Retrieves or Creates (if it does not exist) a Metadata resource. + + Args: + resource_id (str): + Required. The portion of the resource name with the format: + projects/123/locations/us-central1/metadataStores///. + schema_title (str): + Required. schema_title identifies the schema title used by the resource. + display_name (str): + Optional. The user-defined name of the resource. + schema_version (str): + Optional. schema_version specifies the version used by the resource. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the resource to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the resource. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to retrieve or create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to retrieve or create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to retrieve or create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (_Resource): + Instantiated representation of the managed Metadata resource. + + """ + + resource = cls._get( + resource_name=resource_id, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + if not resource: + logging.info(f"Creating Resource {resource_id}") + resource = cls._create( + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + return resource + + def update( + self, + metadata: Dict, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Updates an existing Metadata resource with new metadata. + + Args: + metadata (Dict): + Required. metadata contains the updated metadata information. + credentials (auth_credentials.Credentials): + Custom credentials to use to update this resource. Overrides + credentials set in aiplatform.init. + + """ + + gca_resource = deepcopy(self._gca_resource) + gca_resource.metadata.update(metadata) + api_client = self._instantiate_client(credentials=credentials) + + update_gca_resource = self._update_resource( + client=api_client, resource=gca_resource, + ) + self._gca_resource = update_gca_resource + + @classmethod + def _create( + cls, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Creates a new Metadata resource. + + Args: + resource_id (str): + Required. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores///. + schema_title (str): + Required. schema_title identifies the schema title used by the resource. + display_name (str): + Optional. The user-defined name of the resource. + schema_version (str): + Optional. schema_version specifies the version used by the resource. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the resource to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the resource. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (_Resource): + Instantiated representation of the managed Metadata resource. + + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + + parent = ( + initializer.global_config.common_location_path( + project=project, location=location + ) + + f"/metadataStores/{metadata_store_id}" + ) + + try: + cls._create_resource( + client=api_client, + parent=parent, + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + ) + except exceptions.AlreadyExists: + logging.info(f"Resource '{resource_id}' already exist") + + return cls( + resource_name=f"{parent}/{cls._resource_noun}/{resource_id}", + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + + @classmethod + def _get( + cls, + resource_name: str, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Optional["_Resource"]: + """Returns a metadata Resource. + + Args: + resource_name (str): + A fully-qualified resource name or resource ID + Example: "projects/123/locations/us-central1/metadataStores/default//my-resource". + or "my-resource" when project and location are initialized or passed. + metadata_store_id (str): + The metadata_store_id portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores///my-resource + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project to get this resource from. Overrides project set in + aiplatform.init. + location (str): + Location to get this resource from. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to get this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (Optional[_Resource]): + An optional instantiated representation of the managed Metadata resource. + + """ + + try: + return cls( + resource_name=resource_name, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + except exceptions.NotFound: + logging.info(f"Resource {resource_name} not found.") + + @classmethod + @abc.abstractmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + """Create resource method.""" + pass + + @classmethod + @abc.abstractmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + """Update resource method.""" + pass + + @staticmethod + def _extract_metadata_store_id(resource_name, resource_noun) -> str: + """Extracts the metadata store id from the resource name. + + Args: + resource_name (str): + Required. A fully-qualified metadata resource name. For example + projects/{project}/locations/{location}/metadataStores/{metadata_store_id}/{resource_noun}/{resource_id}. + resource_noun (str): + Required. The resource_noun portion of the resource_name + Returns: + metadata_store_id (str): + The metadata store id for the particular resource name. + Raises: + ValueError if it does not exist. + """ + pattern = re.compile( + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/metadataStores\/(?P[\w-]+)\/" + + resource_noun + + r"\/(?P[\w-]+)$" + ) + match = pattern.match(resource_name) + if not match: + raise ValueError( + f"failed to extract metadata_store_id from resource {resource_name}" + ) + return match["store"] diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index 9450d3f425..a7e96b776d 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -71,9 +71,8 @@ job_service_client_v1.JobServiceClient, ) -# TODO(b/170334098): Add support for resource names more than one level deep RESOURCE_NAME_PATTERN = re.compile( - r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P\w+)\/(?P[\w-]+)$" + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P[\w\-\/]+)\/(?P[\w-]+)$" ) RESOURCE_ID_PATTERN = re.compile(r"^[\w-]+$") @@ -109,10 +108,12 @@ def extract_fields_from_resource_name( Required. A fully-qualified AI Platform (Unified) resource name resource_noun (str): - A plural resource noun to validate the resource name against. + A resource noun to validate the resource name against. For example, you would pass "datasets" to validate "projects/123/locations/us-central1/datasets/456". - + In the case of deeper naming structures, e.g., + "projects/123/locations/us-central1/metadataStores/123/contexts/456", + you would pass "metadataStores/123/contexts" as the resource_noun. Returns: fields (Fields): A named tuple containing four extracted fields from a resource name: @@ -143,9 +144,12 @@ def full_resource_name( Required. A fully-qualified AI Platform (Unified) resource name or resource ID. resource_noun (str): - A plural resource noun to validate the resource name against. + A resource noun to validate the resource name against. For example, you would pass "datasets" to validate "projects/123/locations/us-central1/datasets/456". + In the case of deeper naming structures, e.g., + "projects/123/locations/us-central1/metadataStores/123/contexts/456", + you would pass "metadataStores/123/contexts" as the resource_noun. project (str): Optional project to retrieve resource_noun from. If not set, project set in aiplatform.init will be used. @@ -162,7 +166,8 @@ def full_resource_name( If resource name, resource ID or project ID not provided. """ validate_resource_noun(resource_noun) - # Fully qualified resource name, i.e. "projects/.../locations/.../datasets/12345" + # Fully qualified resource name, e.g., "projects/.../locations/.../datasets/12345" or + # "projects/.../locations/.../metadataStores/.../contexts/12345" valid_name = extract_fields_from_resource_name( resource_name=resource_name, resource_noun=resource_noun ) diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 1d97ad2e9a..dcb33f5b42 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -19,11 +19,13 @@ import os import pytest from unittest import mock +from unittest.mock import patch import google.auth from google.auth import credentials from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata.metadata import metadata_service from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils @@ -37,6 +39,7 @@ _TEST_LOCATION_2 = "europe-west4" _TEST_INVALID_LOCATION = "test-invalid-location" _TEST_EXPERIMENT = "test-experiment" +_TEST_RUN = "test-run" _TEST_STAGING_BUCKET = "test-bucket" @@ -69,9 +72,23 @@ def test_init_location_with_invalid_location_raises(self): with pytest.raises(ValueError): initializer.global_config.init(location=_TEST_INVALID_LOCATION) - def test_init_experiment_sets_experiment(self): + @patch.object(metadata_service, "set_experiment") + def test_init_experiment_calls_metadata_service(self, set_experiment_mock): initializer.global_config.init(experiment=_TEST_EXPERIMENT) - assert initializer.global_config.experiment == _TEST_EXPERIMENT + set_experiment_mock.assert_called_once_with(_TEST_EXPERIMENT) + + def test_init_run_alone_without_experiment_raises(self): + with pytest.raises(ValueError): + initializer.global_config.init(run=_TEST_RUN) + + @patch.object(metadata_service, "set_run") + @patch.object(metadata_service, "set_experiment") + def test_init_experiment_and_run_calls_metadata_service( + self, set_experiment_mock, set_run_mock + ): + initializer.global_config.init(experiment=_TEST_EXPERIMENT, run=_TEST_RUN) + set_experiment_mock.assert_called_once_with(_TEST_EXPERIMENT) + set_run_mock.assert_called_once_with(_TEST_RUN) def test_init_staging_bucket_sets_staging_bucket(self): initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py new file mode 100644 index 0000000000..dca0c569fa --- /dev/null +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from importlib import reload +from unittest.mock import patch + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata import context +from google.cloud.aiplatform.metadata import artifact +from google.cloud.aiplatform.metadata import execution + +from google.cloud.aiplatform_v1beta1 import MetadataServiceClient +from google.cloud.aiplatform_v1beta1 import Context as GapicContext +from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_METADATA_STORE = "test-metadata-store" +_TEST_ALT_LOCATION = "europe-west4" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/{_TEST_METADATA_STORE}" + +# resource attributes +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_SCHEMA_TITLE = "test.Example" +_TEST_SCHEMA_VERSION = "0.0.1" +_TEST_DESCRIPTION = "test description" +_TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True} + +# context +_TEST_CONTEXT_ID = "test-context-id" +_TEST_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_CONTEXT_ID}" + +# artifact +_TEST_ARTIFACT_ID = "test-artifact-id" +_TEST_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_ARTIFACT_ID}" + +# execution +_TEST_EXECUTION_ID = "test-execution-id" +_TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}" + + +@pytest.fixture +def get_context_mock(): + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: + get_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield get_context_mock + + +@pytest.fixture +def create_context_mock(): + with patch.object(MetadataServiceClient, "create_context") as create_context_mock: + create_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_context_mock + + +@pytest.fixture +def get_execution_mock(): + with patch.object(MetadataServiceClient, "get_execution") as get_execution_mock: + get_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield get_execution_mock + + +@pytest.fixture +def create_execution_mock(): + with patch.object( + MetadataServiceClient, "create_execution" + ) as create_execution_mock: + create_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_execution_mock + + +@pytest.fixture +def get_artifact_mock(): + with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: + get_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield get_artifact_mock + + +@pytest.fixture +def create_artifact_mock(): + with patch.object(MetadataServiceClient, "create_artifact") as create_artifact_mock: + create_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_artifact_mock + + +class TestContext: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_context(self, get_context_mock): + aiplatform.init(project=_TEST_PROJECT) + context._Context(resource_name=_TEST_CONTEXT_NAME) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + def test_init_context_with_id(self, get_context_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + context._Context( + resource_name=_TEST_CONTEXT_ID, metadata_store_id=_TEST_METADATA_STORE + ) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + @pytest.mark.usefixtures("get_context_mock") + def test_create_context(self, create_context_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context._Context._create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_context = GapicContext( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + create_context_mock.assert_called_once_with( + parent=_TEST_PARENT, context_id=_TEST_CONTEXT_ID, context=expected_context, + ) + + expected_context.name = _TEST_CONTEXT_NAME + assert my_context._gca_resource == expected_context + + +class TestExecution: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_execution(self, get_execution_mock): + aiplatform.init(project=_TEST_PROJECT) + execution._Execution(resource_name=_TEST_EXECUTION_NAME) + get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) + + def test_init_execution_with_id(self, get_execution_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + execution._Execution( + resource_name=_TEST_EXECUTION_ID, metadata_store_id=_TEST_METADATA_STORE + ) + get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) + + @pytest.mark.usefixtures("get_execution_mock") + def test_create_execution(self, create_execution_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_execution = execution._Execution._create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_execution = GapicExecution( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + create_execution_mock.assert_called_once_with( + parent=_TEST_PARENT, + execution_id=_TEST_EXECUTION_ID, + execution=expected_execution, + ) + + expected_execution.name = _TEST_EXECUTION_NAME + assert my_execution._gca_resource == expected_execution + + +class TestArtifact: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_artifact(self, get_artifact_mock): + aiplatform.init(project=_TEST_PROJECT) + artifact._Artifact(resource_name=_TEST_ARTIFACT_NAME) + get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) + + def test_init_artifact_with_id(self, get_artifact_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + artifact._Artifact( + resource_name=_TEST_ARTIFACT_ID, metadata_store_id=_TEST_METADATA_STORE + ) + get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) + + @pytest.mark.usefixtures("get_artifact_mock") + def test_create_artifact(self, create_artifact_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_artifact = artifact._Artifact._create( + resource_id=_TEST_ARTIFACT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_artifact = GapicArtifact( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + create_artifact_mock.assert_called_once_with( + parent=_TEST_PARENT, + artifact_id=_TEST_ARTIFACT_ID, + artifact=expected_artifact, + ) + + expected_artifact.name = _TEST_ARTIFACT_NAME + assert my_artifact._gca_resource == expected_artifact diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata_store.py similarity index 98% rename from tests/unit/aiplatform/test_metadata.py rename to tests/unit/aiplatform/test_metadata_store.py index 8da3539380..d8c38b8baf 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata_store.py @@ -190,7 +190,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_default_metadata_st project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, ) - my_metadata_store = metadata_store._MetadataStore.create( + my_metadata_store = metadata_store._MetadataStore._create( encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, ) @@ -211,7 +211,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_default_metadata_st def test_create_non_default_metadata_store(self, create_metadata_store_mock): aiplatform.init(project=_TEST_PROJECT) - my_metadata_store = metadata_store._MetadataStore.create( + my_metadata_store = metadata_store._MetadataStore._create( metadata_store_id=_TEST_ID, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, ) diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 03ca7cd6fe..c5ce327db8 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -102,6 +102,18 @@ def test_extract_fields_from_resource_name_with_extracted_fields( "batchPredictionJobs", False, ), + # Expects pattern "projects/.../locations/.../metadataStores/.../contexts/..." + ( + "projects/857392/locations/us-central1/metadataStores/default/contexts/123", + "metadataStores/default/contexts", + True, + ), + # Expects pattern "projects/.../locations/.../tensorboards/.../experiments/.../runs/.../timeSeries/..." + ( + "projects/857392/locations/us-central1/tensorboards/123/experiments/456/runs/789/timeSeries/1", + "tensorboards/123/experiments/456/runs/789/timeSeries", + True, + ), ], ) def test_extract_fields_from_resource_name_with_resource_noun( @@ -141,6 +153,18 @@ def test_invalid_region_does_not_raise_with_valid_region(): "us-west20", "projects/857392/locations/us-central1/trainingPipelines/347292", ), + ( + "metadataStores/default/contexts", + "123456", + "europe-west4", + "projects/857392/locations/us-central1/metadataStores/default/contexts/123", + ), + ( + "tensorboards/123/experiments/456/runs/789/timeSeries", + "857392", + "us-central1", + "projects/857392/locations/us-central1/tensorboards/123/experiments/456/runs/789/timeSeries/1", + ), ], ) def test_full_resource_name_with_full_name( @@ -175,6 +199,20 @@ def test_full_resource_name_with_full_name( "us-central1", "projects/857392/locations/us-central1/trainingPipelines/347292", ), + ( + "123", + "metadataStores/default/contexts", + "857392", + "us-central1", + "projects/857392/locations/us-central1/metadataStores/default/contexts/123", + ), + ( + "1", + "tensorboards/123/experiments/456/runs/789/timeSeries", + "857392", + "us-central1", + "projects/857392/locations/us-central1/tensorboards/123/experiments/456/runs/789/timeSeries/1", + ), ], ) def test_full_resource_name_with_partial_name( From 09bec9137f84f8bdc2b5465351b65b422f9df02a Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 13 Apr 2021 14:37:59 -0400 Subject: [PATCH 08/36] test --- .../index_endpoint_service.rst | 11 + docs/aiplatform_v1beta1/index_service.rst | 11 + docs/aiplatform_v1beta1/services.rst | 2 + google/cloud/aiplatform_v1beta1/__init__.py | 60 + .../index_endpoint_service/__init__.py | 24 + .../index_endpoint_service/async_client.py | 843 +++++ .../services/index_endpoint_service/client.py | 1040 ++++++ .../services/index_endpoint_service/pagers.py | 143 + .../transports/__init__.py | 35 + .../index_endpoint_service/transports/base.py | 220 ++ .../index_endpoint_service/transports/grpc.py | 429 +++ .../transports/grpc_asyncio.py | 434 +++ .../services/index_service/__init__.py | 24 + .../services/index_service/async_client.py | 653 ++++ .../services/index_service/client.py | 839 +++++ .../services/index_service/pagers.py | 143 + .../index_service/transports/__init__.py | 35 + .../services/index_service/transports/base.py | 191 ++ .../services/index_service/transports/grpc.py | 375 +++ .../index_service/transports/grpc_asyncio.py | 380 +++ .../aiplatform_v1beta1/types/__init__.py | 66 + .../types/deployed_index_ref.py | 46 + .../cloud/aiplatform_v1beta1/types/index.py | 127 + .../types/index_endpoint.py | 262 ++ .../types/index_endpoint_service.py | 302 ++ .../aiplatform_v1beta1/types/index_service.py | 303 ++ .../snippets/create_custom_job_sample_test.py | 2 +- ...r_tuning_job_python_package_sample_test.py | 2 +- ...e_hyperparameter_tuning_job_sample_test.py | 2 +- ...om_training_managed_dataset_sample_test.py | 2 +- .../test_index_endpoint_service.py | 2863 +++++++++++++++++ .../aiplatform_v1beta1/test_index_service.py | 2317 +++++++++++++ 32 files changed, 12182 insertions(+), 4 deletions(-) create mode 100644 docs/aiplatform_v1beta1/index_endpoint_service.rst create mode 100644 docs/aiplatform_v1beta1/index_service.rst create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/pagers.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py create mode 100644 google/cloud/aiplatform_v1beta1/types/index.py create mode 100644 google/cloud/aiplatform_v1beta1/types/index_endpoint.py create mode 100644 google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py create mode 100644 google/cloud/aiplatform_v1beta1/types/index_service.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_index_service.py diff --git a/docs/aiplatform_v1beta1/index_endpoint_service.rst b/docs/aiplatform_v1beta1/index_endpoint_service.rst new file mode 100644 index 0000000000..2389e5bf64 --- /dev/null +++ b/docs/aiplatform_v1beta1/index_endpoint_service.rst @@ -0,0 +1,11 @@ +IndexEndpointService +-------------------------------------- + +.. automodule:: google.cloud.aiplatform_v1beta1.services.index_endpoint_service + :members: + :inherited-members: + + +.. automodule:: google.cloud.aiplatform_v1beta1.services.index_endpoint_service.pagers + :members: + :inherited-members: diff --git a/docs/aiplatform_v1beta1/index_service.rst b/docs/aiplatform_v1beta1/index_service.rst new file mode 100644 index 0000000000..e42ade6eaa --- /dev/null +++ b/docs/aiplatform_v1beta1/index_service.rst @@ -0,0 +1,11 @@ +IndexService +------------------------------ + +.. automodule:: google.cloud.aiplatform_v1beta1.services.index_service + :members: + :inherited-members: + + +.. automodule:: google.cloud.aiplatform_v1beta1.services.index_service.pagers + :members: + :inherited-members: diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index 7197956571..f715a7c1f4 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -7,6 +7,8 @@ Services for Google Cloud Aiplatform v1beta1 API endpoint_service featurestore_online_serving_service featurestore_service + index_endpoint_service + index_service job_service metadata_service migration_service diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 0ec5663b24..968477c4d7 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -19,6 +19,8 @@ from .services.endpoint_service import EndpointServiceClient from .services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceClient from .services.featurestore_service import FeaturestoreServiceClient +from .services.index_endpoint_service import IndexEndpointServiceClient +from .services.index_service import IndexServiceClient from .services.job_service import JobServiceClient from .services.metadata_service import MetadataServiceClient from .services.migration_service import MigrationServiceClient @@ -66,6 +68,7 @@ from .types.dataset_service import ListDatasetsRequest from .types.dataset_service import ListDatasetsResponse from .types.dataset_service import UpdateDatasetRequest +from .types.deployed_index_ref import DeployedIndexRef from .types.deployed_model_ref import DeployedModelRef from .types.encryption_spec import EncryptionSpec from .types.endpoint import DeployedModel @@ -148,6 +151,33 @@ from .types.featurestore_service import UpdateFeaturestoreOperationMetadata from .types.featurestore_service import UpdateFeaturestoreRequest from .types.hyperparameter_tuning_job import HyperparameterTuningJob +from .types.index import Index +from .types.index_endpoint import DeployedIndex +from .types.index_endpoint import DeployedIndexAuthConfig +from .types.index_endpoint import IndexEndpoint +from .types.index_endpoint import IndexPrivateEndpoints +from .types.index_endpoint_service import CreateIndexEndpointOperationMetadata +from .types.index_endpoint_service import CreateIndexEndpointRequest +from .types.index_endpoint_service import DeleteIndexEndpointRequest +from .types.index_endpoint_service import DeployIndexOperationMetadata +from .types.index_endpoint_service import DeployIndexRequest +from .types.index_endpoint_service import DeployIndexResponse +from .types.index_endpoint_service import GetIndexEndpointRequest +from .types.index_endpoint_service import ListIndexEndpointsRequest +from .types.index_endpoint_service import ListIndexEndpointsResponse +from .types.index_endpoint_service import UndeployIndexOperationMetadata +from .types.index_endpoint_service import UndeployIndexRequest +from .types.index_endpoint_service import UndeployIndexResponse +from .types.index_endpoint_service import UpdateIndexEndpointRequest +from .types.index_service import CreateIndexOperationMetadata +from .types.index_service import CreateIndexRequest +from .types.index_service import DeleteIndexRequest +from .types.index_service import GetIndexRequest +from .types.index_service import ListIndexesRequest +from .types.index_service import ListIndexesResponse +from .types.index_service import NearestNeighborSearchOperationMetadata +from .types.index_service import UpdateIndexOperationMetadata +from .types.index_service import UpdateIndexRequest from .types.io import AvroSource from .types.io import BigQueryDestination from .types.io import BigQuerySource @@ -403,6 +433,10 @@ 'CreateFeaturestoreOperationMetadata', 'CreateFeaturestoreRequest', 'CreateHyperparameterTuningJobRequest', + 'CreateIndexEndpointOperationMetadata', + 'CreateIndexEndpointRequest', + 'CreateIndexOperationMetadata', + 'CreateIndexRequest', 'CreateMetadataSchemaRequest', 'CreateMetadataStoreOperationMetadata', 'CreateMetadataStoreRequest', @@ -431,6 +465,8 @@ 'DeleteFeatureRequest', 'DeleteFeaturestoreRequest', 'DeleteHyperparameterTuningJobRequest', + 'DeleteIndexEndpointRequest', + 'DeleteIndexRequest', 'DeleteMetadataStoreOperationMetadata', 'DeleteMetadataStoreRequest', 'DeleteModelDeploymentMonitoringJobRequest', @@ -440,9 +476,15 @@ 'DeleteStudyRequest', 'DeleteTrainingPipelineRequest', 'DeleteTrialRequest', + 'DeployIndexOperationMetadata', + 'DeployIndexRequest', + 'DeployIndexResponse', 'DeployModelOperationMetadata', 'DeployModelRequest', 'DeployModelResponse', + 'DeployedIndex', + 'DeployedIndexAuthConfig', + 'DeployedIndexRef', 'DeployedModel', 'DeployedModelRef', 'DestinationFeatureSetting', @@ -499,6 +541,8 @@ 'GetFeatureRequest', 'GetFeaturestoreRequest', 'GetHyperparameterTuningJobRequest', + 'GetIndexEndpointRequest', + 'GetIndexRequest', 'GetMetadataSchemaRequest', 'GetMetadataStoreRequest', 'GetModelDeploymentMonitoringJobRequest', @@ -518,6 +562,11 @@ 'ImportFeatureValuesOperationMetadata', 'ImportFeatureValuesRequest', 'ImportFeatureValuesResponse', + 'Index', + 'IndexEndpoint', + 'IndexEndpointServiceClient', + 'IndexPrivateEndpoints', + 'IndexServiceClient', 'InputDataConfig', 'Int64Array', 'IntegratedGradientsAttribution', @@ -552,6 +601,10 @@ 'ListFeaturestoresResponse', 'ListHyperparameterTuningJobsRequest', 'ListHyperparameterTuningJobsResponse', + 'ListIndexEndpointsRequest', + 'ListIndexEndpointsResponse', + 'ListIndexesRequest', + 'ListIndexesResponse', 'ListMetadataSchemasRequest', 'ListMetadataSchemasResponse', 'ListMetadataStoresRequest', @@ -598,6 +651,7 @@ 'ModelMonitoringObjectiveConfig', 'ModelMonitoringStatsAnomalies', 'ModelServiceClient', + 'NearestNeighborSearchOperationMetadata', 'PauseModelDeploymentMonitoringJobRequest', 'PipelineServiceClient', 'PipelineState', @@ -643,6 +697,9 @@ 'TrainingConfig', 'TrainingPipeline', 'Trial', + 'UndeployIndexOperationMetadata', + 'UndeployIndexRequest', + 'UndeployIndexResponse', 'UndeployModelOperationMetadata', 'UndeployModelRequest', 'UndeployModelResponse', @@ -655,6 +712,9 @@ 'UpdateFeatureRequest', 'UpdateFeaturestoreOperationMetadata', 'UpdateFeaturestoreRequest', + 'UpdateIndexEndpointRequest', + 'UpdateIndexOperationMetadata', + 'UpdateIndexRequest', 'UpdateModelDeploymentMonitoringJobOperationMetadata', 'UpdateModelDeploymentMonitoringJobRequest', 'UpdateModelRequest', diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py new file mode 100644 index 0000000000..853d7b928c --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import IndexEndpointServiceClient +from .async_client import IndexEndpointServiceAsyncClient + +__all__ = ( + 'IndexEndpointServiceClient', + 'IndexEndpointServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py new file mode 100644 index 0000000000..704dd1fda4 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py @@ -0,0 +1,843 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import IndexEndpointServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import IndexEndpointServiceGrpcAsyncIOTransport +from .client import IndexEndpointServiceClient + + +class IndexEndpointServiceAsyncClient: + """A service for managing AI Platform's IndexEndpoints.""" + + _client: IndexEndpointServiceClient + + DEFAULT_ENDPOINT = IndexEndpointServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = IndexEndpointServiceClient.DEFAULT_MTLS_ENDPOINT + + index_path = staticmethod(IndexEndpointServiceClient.index_path) + parse_index_path = staticmethod(IndexEndpointServiceClient.parse_index_path) + index_endpoint_path = staticmethod(IndexEndpointServiceClient.index_endpoint_path) + parse_index_endpoint_path = staticmethod(IndexEndpointServiceClient.parse_index_endpoint_path) + index_endpoint_path = staticmethod(IndexEndpointServiceClient.index_endpoint_path) + parse_index_endpoint_path = staticmethod(IndexEndpointServiceClient.parse_index_endpoint_path) + + common_billing_account_path = staticmethod(IndexEndpointServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(IndexEndpointServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(IndexEndpointServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(IndexEndpointServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(IndexEndpointServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(IndexEndpointServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(IndexEndpointServiceClient.common_project_path) + parse_common_project_path = staticmethod(IndexEndpointServiceClient.parse_common_project_path) + + common_location_path = staticmethod(IndexEndpointServiceClient.common_location_path) + parse_common_location_path = staticmethod(IndexEndpointServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexEndpointServiceAsyncClient: The constructed client. + """ + return IndexEndpointServiceClient.from_service_account_info.__func__(IndexEndpointServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexEndpointServiceAsyncClient: The constructed client. + """ + return IndexEndpointServiceClient.from_service_account_file.__func__(IndexEndpointServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> IndexEndpointServiceTransport: + """Return the transport used by the client instance. + + Returns: + IndexEndpointServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(IndexEndpointServiceClient).get_transport_class, type(IndexEndpointServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, IndexEndpointServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the index endpoint service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.IndexEndpointServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = IndexEndpointServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_index_endpoint(self, + request: index_endpoint_service.CreateIndexEndpointRequest = None, + *, + parent: str = None, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates an IndexEndpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateIndexEndpointRequest`): + The request object. Request message for + ``IndexEndpointService.CreateIndexEndpoint``. + parent (:class:`str`): + Required. The resource name of the Location to create + the IndexEndpoint in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + index_endpoint (:class:`google.cloud.aiplatform_v1beta1.types.IndexEndpoint`): + Required. The IndexEndpoint to + create. + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.IndexEndpoint` Indexes are deployed into it. An IndexEndpoint can have multiple + DeployedIndexes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, index_endpoint]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.CreateIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if index_endpoint is not None: + request.index_endpoint = index_endpoint + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_index_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_index_endpoint.IndexEndpoint, + metadata_type=index_endpoint_service.CreateIndexEndpointOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_index_endpoint(self, + request: index_endpoint_service.GetIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index_endpoint.IndexEndpoint: + r"""Gets an IndexEndpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetIndexEndpointRequest`): + The request object. Request message for + ``IndexEndpointService.GetIndexEndpoint`` + name (:class:`str`): + Required. The name of the IndexEndpoint resource. + Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.IndexEndpoint: + Indexes are deployed into it. An + IndexEndpoint can have multiple + DeployedIndexes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.GetIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_index_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_index_endpoints(self, + request: index_endpoint_service.ListIndexEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexEndpointsAsyncPager: + r"""Lists IndexEndpoints in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsRequest`): + The request object. Request message for + ``IndexEndpointService.ListIndexEndpoints``. + parent (:class:`str`): + Required. The resource name of the Location from which + to list the IndexEndpoints. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.index_endpoint_service.pagers.ListIndexEndpointsAsyncPager: + Response message for + ``IndexEndpointService.ListIndexEndpoints``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.ListIndexEndpointsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_index_endpoints, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListIndexEndpointsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_index_endpoint(self, + request: index_endpoint_service.UpdateIndexEndpointRequest = None, + *, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_index_endpoint.IndexEndpoint: + r"""Updates an IndexEndpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateIndexEndpointRequest`): + The request object. Request message for + ``IndexEndpointService.UpdateIndexEndpoint``. + index_endpoint (:class:`google.cloud.aiplatform_v1beta1.types.IndexEndpoint`): + Required. The IndexEndpoint which + replaces the resource on the server. + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. See + `FieldMask `__. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.IndexEndpoint: + Indexes are deployed into it. An + IndexEndpoint can have multiple + DeployedIndexes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index_endpoint, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.UpdateIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index_endpoint is not None: + request.index_endpoint = index_endpoint + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_index_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index_endpoint.name', request.index_endpoint.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_index_endpoint(self, + request: index_endpoint_service.DeleteIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an IndexEndpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteIndexEndpointRequest`): + The request object. Request message for + ``IndexEndpointService.DeleteIndexEndpoint``. + name (:class:`str`): + Required. The name of the IndexEndpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.DeleteIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_index_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def deploy_index(self, + request: index_endpoint_service.DeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index: gca_index_endpoint.DeployedIndex = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deploys an Index into this IndexEndpoint, creating a + DeployedIndex within it. + Only non-empty Indexes can be deployed. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeployIndexRequest`): + The request object. Request message for + ``IndexEndpointService.DeployIndex``. + index_endpoint (:class:`str`): + Required. The name of the IndexEndpoint resource into + which to deploy an Index. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_index (:class:`google.cloud.aiplatform_v1beta1.types.DeployedIndex`): + Required. The DeployedIndex to be + created within the IndexEndpoint. + + This corresponds to the ``deployed_index`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.DeployIndexResponse` + Response message for + ``IndexEndpointService.DeployIndex``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index_endpoint, deployed_index]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.DeployIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index_endpoint is not None: + request.index_endpoint = index_endpoint + if deployed_index is not None: + request.deployed_index = deployed_index + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.deploy_index, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index_endpoint', request.index_endpoint), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + index_endpoint_service.DeployIndexResponse, + metadata_type=index_endpoint_service.DeployIndexOperationMetadata, + ) + + # Done; return the response. + return response + + async def undeploy_index(self, + request: index_endpoint_service.UndeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Undeploys an Index from an IndexEndpoint, removing a + DeployedIndex from it, and freeing all resources it's + using. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UndeployIndexRequest`): + The request object. Request message for + ``IndexEndpointService.UndeployIndex``. + index_endpoint (:class:`str`): + Required. The name of the IndexEndpoint resource from + which to undeploy an Index. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_index_id (:class:`str`): + Required. The ID of the DeployedIndex + to be undeployed from the IndexEndpoint. + + This corresponds to the ``deployed_index_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.UndeployIndexResponse` + Response message for + ``IndexEndpointService.UndeployIndex``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index_endpoint, deployed_index_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_endpoint_service.UndeployIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index_endpoint is not None: + request.index_endpoint = index_endpoint + if deployed_index_id is not None: + request.deployed_index_id = deployed_index_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.undeploy_index, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index_endpoint', request.index_endpoint), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + index_endpoint_service.UndeployIndexResponse, + metadata_type=index_endpoint_service.UndeployIndexOperationMetadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'IndexEndpointServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py new file mode 100644 index 0000000000..9933c45371 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -0,0 +1,1040 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import IndexEndpointServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import IndexEndpointServiceGrpcTransport +from .transports.grpc_asyncio import IndexEndpointServiceGrpcAsyncIOTransport + + +class IndexEndpointServiceClientMeta(type): + """Metaclass for the IndexEndpointService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[IndexEndpointServiceTransport]] + _transport_registry['grpc'] = IndexEndpointServiceGrpcTransport + _transport_registry['grpc_asyncio'] = IndexEndpointServiceGrpcAsyncIOTransport + + def get_transport_class(cls, + label: str = None, + ) -> Type[IndexEndpointServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class IndexEndpointServiceClient(metaclass=IndexEndpointServiceClientMeta): + """A service for managing AI Platform's IndexEndpoints.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexEndpointServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexEndpointServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> IndexEndpointServiceTransport: + """Return the transport used by the client instance. + + Returns: + IndexEndpointServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def index_path(project: str,location: str,index: str,) -> str: + """Return a fully-qualified index string.""" + return "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + + @staticmethod + def parse_index_path(path: str) -> Dict[str,str]: + """Parse a index path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexes/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def index_endpoint_path(project: str,location: str,index_endpoint: str,) -> str: + """Return a fully-qualified index_endpoint string.""" + return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + + @staticmethod + def parse_index_endpoint_path(path: str) -> Dict[str,str]: + """Parse a index_endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def index_endpoint_path(project: str,location: str,index_endpoint: str,) -> str: + """Return a fully-qualified index_endpoint string.""" + return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + + @staticmethod + def parse_index_endpoint_path(path: str) -> Dict[str,str]: + """Parse a index_endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, IndexEndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the index endpoint service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, IndexEndpointServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, IndexEndpointServiceTransport): + # transport is a IndexEndpointServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_index_endpoint(self, + request: index_endpoint_service.CreateIndexEndpointRequest = None, + *, + parent: str = None, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates an IndexEndpoint. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateIndexEndpointRequest): + The request object. Request message for + ``IndexEndpointService.CreateIndexEndpoint``. + parent (str): + Required. The resource name of the Location to create + the IndexEndpoint in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + index_endpoint (google.cloud.aiplatform_v1beta1.types.IndexEndpoint): + Required. The IndexEndpoint to + create. + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.IndexEndpoint` Indexes are deployed into it. An IndexEndpoint can have multiple + DeployedIndexes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, index_endpoint]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.CreateIndexEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.CreateIndexEndpointRequest): + request = index_endpoint_service.CreateIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if index_endpoint is not None: + request.index_endpoint = index_endpoint + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_index_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_index_endpoint.IndexEndpoint, + metadata_type=index_endpoint_service.CreateIndexEndpointOperationMetadata, + ) + + # Done; return the response. + return response + + def get_index_endpoint(self, + request: index_endpoint_service.GetIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index_endpoint.IndexEndpoint: + r"""Gets an IndexEndpoint. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetIndexEndpointRequest): + The request object. Request message for + ``IndexEndpointService.GetIndexEndpoint`` + name (str): + Required. The name of the IndexEndpoint resource. + Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.IndexEndpoint: + Indexes are deployed into it. An + IndexEndpoint can have multiple + DeployedIndexes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.GetIndexEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.GetIndexEndpointRequest): + request = index_endpoint_service.GetIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_index_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_index_endpoints(self, + request: index_endpoint_service.ListIndexEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexEndpointsPager: + r"""Lists IndexEndpoints in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsRequest): + The request object. Request message for + ``IndexEndpointService.ListIndexEndpoints``. + parent (str): + Required. The resource name of the Location from which + to list the IndexEndpoints. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.index_endpoint_service.pagers.ListIndexEndpointsPager: + Response message for + ``IndexEndpointService.ListIndexEndpoints``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.ListIndexEndpointsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.ListIndexEndpointsRequest): + request = index_endpoint_service.ListIndexEndpointsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_index_endpoints] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListIndexEndpointsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_index_endpoint(self, + request: index_endpoint_service.UpdateIndexEndpointRequest = None, + *, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_index_endpoint.IndexEndpoint: + r"""Updates an IndexEndpoint. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateIndexEndpointRequest): + The request object. Request message for + ``IndexEndpointService.UpdateIndexEndpoint``. + index_endpoint (google.cloud.aiplatform_v1beta1.types.IndexEndpoint): + Required. The IndexEndpoint which + replaces the resource on the server. + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + `FieldMask `__. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.IndexEndpoint: + Indexes are deployed into it. An + IndexEndpoint can have multiple + DeployedIndexes. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index_endpoint, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.UpdateIndexEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.UpdateIndexEndpointRequest): + request = index_endpoint_service.UpdateIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index_endpoint is not None: + request.index_endpoint = index_endpoint + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_index_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index_endpoint.name', request.index_endpoint.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_index_endpoint(self, + request: index_endpoint_service.DeleteIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes an IndexEndpoint. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteIndexEndpointRequest): + The request object. Request message for + ``IndexEndpointService.DeleteIndexEndpoint``. + name (str): + Required. The name of the IndexEndpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.DeleteIndexEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.DeleteIndexEndpointRequest): + request = index_endpoint_service.DeleteIndexEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_index_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def deploy_index(self, + request: index_endpoint_service.DeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index: gca_index_endpoint.DeployedIndex = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deploys an Index into this IndexEndpoint, creating a + DeployedIndex within it. + Only non-empty Indexes can be deployed. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeployIndexRequest): + The request object. Request message for + ``IndexEndpointService.DeployIndex``. + index_endpoint (str): + Required. The name of the IndexEndpoint resource into + which to deploy an Index. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_index (google.cloud.aiplatform_v1beta1.types.DeployedIndex): + Required. The DeployedIndex to be + created within the IndexEndpoint. + + This corresponds to the ``deployed_index`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.DeployIndexResponse` + Response message for + ``IndexEndpointService.DeployIndex``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index_endpoint, deployed_index]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.DeployIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.DeployIndexRequest): + request = index_endpoint_service.DeployIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index_endpoint is not None: + request.index_endpoint = index_endpoint + if deployed_index is not None: + request.deployed_index = deployed_index + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.deploy_index] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index_endpoint', request.index_endpoint), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + index_endpoint_service.DeployIndexResponse, + metadata_type=index_endpoint_service.DeployIndexOperationMetadata, + ) + + # Done; return the response. + return response + + def undeploy_index(self, + request: index_endpoint_service.UndeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Undeploys an Index from an IndexEndpoint, removing a + DeployedIndex from it, and freeing all resources it's + using. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UndeployIndexRequest): + The request object. Request message for + ``IndexEndpointService.UndeployIndex``. + index_endpoint (str): + Required. The name of the IndexEndpoint resource from + which to undeploy an Index. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + + This corresponds to the ``index_endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_index_id (str): + Required. The ID of the DeployedIndex + to be undeployed from the IndexEndpoint. + + This corresponds to the ``deployed_index_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.UndeployIndexResponse` + Response message for + ``IndexEndpointService.UndeployIndex``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index_endpoint, deployed_index_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_endpoint_service.UndeployIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_endpoint_service.UndeployIndexRequest): + request = index_endpoint_service.UndeployIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index_endpoint is not None: + request.index_endpoint = index_endpoint + if deployed_index_id is not None: + request.deployed_index_id = deployed_index_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.undeploy_index] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index_endpoint', request.index_endpoint), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + index_endpoint_service.UndeployIndexResponse, + metadata_type=index_endpoint_service.UndeployIndexOperationMetadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'IndexEndpointServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py new file mode 100644 index 0000000000..7c38beadfd --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional + +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service + + +class ListIndexEndpointsPager: + """A pager for iterating through ``list_index_endpoints`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``index_endpoints`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListIndexEndpoints`` requests and continue to iterate + through the ``index_endpoints`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., index_endpoint_service.ListIndexEndpointsResponse], + request: index_endpoint_service.ListIndexEndpointsRequest, + response: index_endpoint_service.ListIndexEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = index_endpoint_service.ListIndexEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[index_endpoint_service.ListIndexEndpointsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[index_endpoint.IndexEndpoint]: + for page in self.pages: + yield from page.index_endpoints + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListIndexEndpointsAsyncPager: + """A pager for iterating through ``list_index_endpoints`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``index_endpoints`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListIndexEndpoints`` requests and continue to iterate + through the ``index_endpoints`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[index_endpoint_service.ListIndexEndpointsResponse]], + request: index_endpoint_service.ListIndexEndpointsRequest, + response: index_endpoint_service.ListIndexEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = index_endpoint_service.ListIndexEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[index_endpoint_service.ListIndexEndpointsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[index_endpoint.IndexEndpoint]: + async def async_generator(): + async for page in self.pages: + for response in page.index_endpoints: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py new file mode 100644 index 0000000000..dd025dddb8 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import IndexEndpointServiceTransport +from .grpc import IndexEndpointServiceGrpcTransport +from .grpc_asyncio import IndexEndpointServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[IndexEndpointServiceTransport]] +_transport_registry['grpc'] = IndexEndpointServiceGrpcTransport +_transport_registry['grpc_asyncio'] = IndexEndpointServiceGrpcAsyncIOTransport + +__all__ = ( + 'IndexEndpointServiceTransport', + 'IndexEndpointServiceGrpcTransport', + 'IndexEndpointServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py new file mode 100644 index 0000000000..e16f56dd80 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service +from google.longrunning import operations_pb2 as operations # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class IndexEndpointServiceTransport(abc.ABC): + """Abstract transport class for IndexEndpointService.""" + + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) + + def __init__( + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_index_endpoint: gapic_v1.method.wrap_method( + self.create_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.get_index_endpoint: gapic_v1.method.wrap_method( + self.get_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.list_index_endpoints: gapic_v1.method.wrap_method( + self.list_index_endpoints, + default_timeout=None, + client_info=client_info, + ), + self.update_index_endpoint: gapic_v1.method.wrap_method( + self.update_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.delete_index_endpoint: gapic_v1.method.wrap_method( + self.delete_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.deploy_index: gapic_v1.method.wrap_method( + self.deploy_index, + default_timeout=None, + client_info=client_info, + ), + self.undeploy_index: gapic_v1.method.wrap_method( + self.undeploy_index, + default_timeout=None, + client_info=client_info, + ), + + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_index_endpoint(self) -> typing.Callable[ + [index_endpoint_service.CreateIndexEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def get_index_endpoint(self) -> typing.Callable[ + [index_endpoint_service.GetIndexEndpointRequest], + typing.Union[ + index_endpoint.IndexEndpoint, + typing.Awaitable[index_endpoint.IndexEndpoint] + ]]: + raise NotImplementedError() + + @property + def list_index_endpoints(self) -> typing.Callable[ + [index_endpoint_service.ListIndexEndpointsRequest], + typing.Union[ + index_endpoint_service.ListIndexEndpointsResponse, + typing.Awaitable[index_endpoint_service.ListIndexEndpointsResponse] + ]]: + raise NotImplementedError() + + @property + def update_index_endpoint(self) -> typing.Callable[ + [index_endpoint_service.UpdateIndexEndpointRequest], + typing.Union[ + gca_index_endpoint.IndexEndpoint, + typing.Awaitable[gca_index_endpoint.IndexEndpoint] + ]]: + raise NotImplementedError() + + @property + def delete_index_endpoint(self) -> typing.Callable[ + [index_endpoint_service.DeleteIndexEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def deploy_index(self) -> typing.Callable[ + [index_endpoint_service.DeployIndexRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def undeploy_index(self) -> typing.Callable[ + [index_endpoint_service.UndeployIndexRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'IndexEndpointServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py new file mode 100644 index 0000000000..274c8cdc6f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py @@ -0,0 +1,429 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import IndexEndpointServiceTransport, DEFAULT_CLIENT_INFO + + +class IndexEndpointServiceGrpcTransport(IndexEndpointServiceTransport): + """gRPC backend transport for IndexEndpointService. + + A service for managing AI Platform's IndexEndpoints. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_index_endpoint(self) -> Callable[ + [index_endpoint_service.CreateIndexEndpointRequest], + operations.Operation]: + r"""Return a callable for the create index endpoint method over gRPC. + + Creates an IndexEndpoint. + + Returns: + Callable[[~.CreateIndexEndpointRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_index_endpoint' not in self._stubs: + self._stubs['create_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/CreateIndexEndpoint', + request_serializer=index_endpoint_service.CreateIndexEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_index_endpoint'] + + @property + def get_index_endpoint(self) -> Callable[ + [index_endpoint_service.GetIndexEndpointRequest], + index_endpoint.IndexEndpoint]: + r"""Return a callable for the get index endpoint method over gRPC. + + Gets an IndexEndpoint. + + Returns: + Callable[[~.GetIndexEndpointRequest], + ~.IndexEndpoint]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_index_endpoint' not in self._stubs: + self._stubs['get_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/GetIndexEndpoint', + request_serializer=index_endpoint_service.GetIndexEndpointRequest.serialize, + response_deserializer=index_endpoint.IndexEndpoint.deserialize, + ) + return self._stubs['get_index_endpoint'] + + @property + def list_index_endpoints(self) -> Callable[ + [index_endpoint_service.ListIndexEndpointsRequest], + index_endpoint_service.ListIndexEndpointsResponse]: + r"""Return a callable for the list index endpoints method over gRPC. + + Lists IndexEndpoints in a Location. + + Returns: + Callable[[~.ListIndexEndpointsRequest], + ~.ListIndexEndpointsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_index_endpoints' not in self._stubs: + self._stubs['list_index_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/ListIndexEndpoints', + request_serializer=index_endpoint_service.ListIndexEndpointsRequest.serialize, + response_deserializer=index_endpoint_service.ListIndexEndpointsResponse.deserialize, + ) + return self._stubs['list_index_endpoints'] + + @property + def update_index_endpoint(self) -> Callable[ + [index_endpoint_service.UpdateIndexEndpointRequest], + gca_index_endpoint.IndexEndpoint]: + r"""Return a callable for the update index endpoint method over gRPC. + + Updates an IndexEndpoint. + + Returns: + Callable[[~.UpdateIndexEndpointRequest], + ~.IndexEndpoint]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_index_endpoint' not in self._stubs: + self._stubs['update_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UpdateIndexEndpoint', + request_serializer=index_endpoint_service.UpdateIndexEndpointRequest.serialize, + response_deserializer=gca_index_endpoint.IndexEndpoint.deserialize, + ) + return self._stubs['update_index_endpoint'] + + @property + def delete_index_endpoint(self) -> Callable[ + [index_endpoint_service.DeleteIndexEndpointRequest], + operations.Operation]: + r"""Return a callable for the delete index endpoint method over gRPC. + + Deletes an IndexEndpoint. + + Returns: + Callable[[~.DeleteIndexEndpointRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_index_endpoint' not in self._stubs: + self._stubs['delete_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeleteIndexEndpoint', + request_serializer=index_endpoint_service.DeleteIndexEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_index_endpoint'] + + @property + def deploy_index(self) -> Callable[ + [index_endpoint_service.DeployIndexRequest], + operations.Operation]: + r"""Return a callable for the deploy index method over gRPC. + + Deploys an Index into this IndexEndpoint, creating a + DeployedIndex within it. + Only non-empty Indexes can be deployed. + + Returns: + Callable[[~.DeployIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'deploy_index' not in self._stubs: + self._stubs['deploy_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeployIndex', + request_serializer=index_endpoint_service.DeployIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['deploy_index'] + + @property + def undeploy_index(self) -> Callable[ + [index_endpoint_service.UndeployIndexRequest], + operations.Operation]: + r"""Return a callable for the undeploy index method over gRPC. + + Undeploys an Index from an IndexEndpoint, removing a + DeployedIndex from it, and freeing all resources it's + using. + + Returns: + Callable[[~.UndeployIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'undeploy_index' not in self._stubs: + self._stubs['undeploy_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UndeployIndex', + request_serializer=index_endpoint_service.UndeployIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['undeploy_index'] + + +__all__ = ( + 'IndexEndpointServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..3b2c0fb5ce --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import IndexEndpointServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import IndexEndpointServiceGrpcTransport + + +class IndexEndpointServiceGrpcAsyncIOTransport(IndexEndpointServiceTransport): + """gRPC AsyncIO backend transport for IndexEndpointService. + + A service for managing AI Platform's IndexEndpoints. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_index_endpoint(self) -> Callable[ + [index_endpoint_service.CreateIndexEndpointRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create index endpoint method over gRPC. + + Creates an IndexEndpoint. + + Returns: + Callable[[~.CreateIndexEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_index_endpoint' not in self._stubs: + self._stubs['create_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/CreateIndexEndpoint', + request_serializer=index_endpoint_service.CreateIndexEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_index_endpoint'] + + @property + def get_index_endpoint(self) -> Callable[ + [index_endpoint_service.GetIndexEndpointRequest], + Awaitable[index_endpoint.IndexEndpoint]]: + r"""Return a callable for the get index endpoint method over gRPC. + + Gets an IndexEndpoint. + + Returns: + Callable[[~.GetIndexEndpointRequest], + Awaitable[~.IndexEndpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_index_endpoint' not in self._stubs: + self._stubs['get_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/GetIndexEndpoint', + request_serializer=index_endpoint_service.GetIndexEndpointRequest.serialize, + response_deserializer=index_endpoint.IndexEndpoint.deserialize, + ) + return self._stubs['get_index_endpoint'] + + @property + def list_index_endpoints(self) -> Callable[ + [index_endpoint_service.ListIndexEndpointsRequest], + Awaitable[index_endpoint_service.ListIndexEndpointsResponse]]: + r"""Return a callable for the list index endpoints method over gRPC. + + Lists IndexEndpoints in a Location. + + Returns: + Callable[[~.ListIndexEndpointsRequest], + Awaitable[~.ListIndexEndpointsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_index_endpoints' not in self._stubs: + self._stubs['list_index_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/ListIndexEndpoints', + request_serializer=index_endpoint_service.ListIndexEndpointsRequest.serialize, + response_deserializer=index_endpoint_service.ListIndexEndpointsResponse.deserialize, + ) + return self._stubs['list_index_endpoints'] + + @property + def update_index_endpoint(self) -> Callable[ + [index_endpoint_service.UpdateIndexEndpointRequest], + Awaitable[gca_index_endpoint.IndexEndpoint]]: + r"""Return a callable for the update index endpoint method over gRPC. + + Updates an IndexEndpoint. + + Returns: + Callable[[~.UpdateIndexEndpointRequest], + Awaitable[~.IndexEndpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_index_endpoint' not in self._stubs: + self._stubs['update_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UpdateIndexEndpoint', + request_serializer=index_endpoint_service.UpdateIndexEndpointRequest.serialize, + response_deserializer=gca_index_endpoint.IndexEndpoint.deserialize, + ) + return self._stubs['update_index_endpoint'] + + @property + def delete_index_endpoint(self) -> Callable[ + [index_endpoint_service.DeleteIndexEndpointRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete index endpoint method over gRPC. + + Deletes an IndexEndpoint. + + Returns: + Callable[[~.DeleteIndexEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_index_endpoint' not in self._stubs: + self._stubs['delete_index_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeleteIndexEndpoint', + request_serializer=index_endpoint_service.DeleteIndexEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_index_endpoint'] + + @property + def deploy_index(self) -> Callable[ + [index_endpoint_service.DeployIndexRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the deploy index method over gRPC. + + Deploys an Index into this IndexEndpoint, creating a + DeployedIndex within it. + Only non-empty Indexes can be deployed. + + Returns: + Callable[[~.DeployIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'deploy_index' not in self._stubs: + self._stubs['deploy_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeployIndex', + request_serializer=index_endpoint_service.DeployIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['deploy_index'] + + @property + def undeploy_index(self) -> Callable[ + [index_endpoint_service.UndeployIndexRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the undeploy index method over gRPC. + + Undeploys an Index from an IndexEndpoint, removing a + DeployedIndex from it, and freeing all resources it's + using. + + Returns: + Callable[[~.UndeployIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'undeploy_index' not in self._stubs: + self._stubs['undeploy_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UndeployIndex', + request_serializer=index_endpoint_service.UndeployIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['undeploy_index'] + + +__all__ = ( + 'IndexEndpointServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py new file mode 100644 index 0000000000..5b6569d841 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import IndexServiceClient +from .async_client import IndexServiceAsyncClient + +__all__ = ( + 'IndexServiceClient', + 'IndexServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py new file mode 100644 index 0000000000..49fc00f568 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py @@ -0,0 +1,653 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.index_service import pagers +from google.cloud.aiplatform_v1beta1.types import deployed_index_ref +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index as gca_index +from google.cloud.aiplatform_v1beta1.types import index_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import IndexServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import IndexServiceGrpcAsyncIOTransport +from .client import IndexServiceClient + + +class IndexServiceAsyncClient: + """A service for creating and managing AI Platform's Index + resources. + """ + + _client: IndexServiceClient + + DEFAULT_ENDPOINT = IndexServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = IndexServiceClient.DEFAULT_MTLS_ENDPOINT + + index_path = staticmethod(IndexServiceClient.index_path) + parse_index_path = staticmethod(IndexServiceClient.parse_index_path) + index_endpoint_path = staticmethod(IndexServiceClient.index_endpoint_path) + parse_index_endpoint_path = staticmethod(IndexServiceClient.parse_index_endpoint_path) + + common_billing_account_path = staticmethod(IndexServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(IndexServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(IndexServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(IndexServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(IndexServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(IndexServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(IndexServiceClient.common_project_path) + parse_common_project_path = staticmethod(IndexServiceClient.parse_common_project_path) + + common_location_path = staticmethod(IndexServiceClient.common_location_path) + parse_common_location_path = staticmethod(IndexServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexServiceAsyncClient: The constructed client. + """ + return IndexServiceClient.from_service_account_info.__func__(IndexServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexServiceAsyncClient: The constructed client. + """ + return IndexServiceClient.from_service_account_file.__func__(IndexServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> IndexServiceTransport: + """Return the transport used by the client instance. + + Returns: + IndexServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(IndexServiceClient).get_transport_class, type(IndexServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, IndexServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the index service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.IndexServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = IndexServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_index(self, + request: index_service.CreateIndexRequest = None, + *, + parent: str = None, + index: gca_index.Index = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates an Index. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateIndexRequest`): + The request object. Request message for + ``IndexService.CreateIndex``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Index in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + index (:class:`google.cloud.aiplatform_v1beta1.types.Index`): + Required. The Index to create. + This corresponds to the ``index`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Index` A representation of a collection of database items organized in a way that + allows for approximate nearest neighbor (a.k.a ANN) + algorithms search. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, index]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_service.CreateIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if index is not None: + request.index = index + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_index, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_index.Index, + metadata_type=index_service.CreateIndexOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_index(self, + request: index_service.GetIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index.Index: + r"""Gets an Index. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetIndexRequest`): + The request object. Request message for + ``IndexService.GetIndex`` + name (:class:`str`): + Required. The name of the Index resource. Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Index: + A representation of a collection of + database items organized in a way that + allows for approximate nearest neighbor + (a.k.a ANN) algorithms search. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_service.GetIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_index, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_indexes(self, + request: index_service.ListIndexesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexesAsyncPager: + r"""Lists Indexes in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListIndexesRequest`): + The request object. Request message for + ``IndexService.ListIndexes``. + parent (:class:`str`): + Required. The resource name of the Location from which + to list the Indexes. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.index_service.pagers.ListIndexesAsyncPager: + Response message for + ``IndexService.ListIndexes``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_service.ListIndexesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_indexes, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListIndexesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_index(self, + request: index_service.UpdateIndexRequest = None, + *, + index: gca_index.Index = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates an Index. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateIndexRequest`): + The request object. Request message for + [IndexService.UpdateModel][]. + index (:class:`google.cloud.aiplatform_v1beta1.types.Index`): + Required. The Index which updates the + resource on the server. + + This corresponds to the ``index`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + The update mask applies to the resource. For the + ``FieldMask`` definition, see + `FieldMask `__. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Index` A representation of a collection of database items organized in a way that + allows for approximate nearest neighbor (a.k.a ANN) + algorithms search. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_service.UpdateIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index is not None: + request.index = index + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_index, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index.name', request.index.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_index.Index, + metadata_type=index_service.UpdateIndexOperationMetadata, + ) + + # Done; return the response. + return response + + async def delete_index(self, + request: index_service.DeleteIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an Index. An Index can only be deleted when all its + ``DeployedIndexes`` + had been undeployed. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteIndexRequest`): + The request object. Request message for + ``IndexService.DeleteIndex``. + name (:class:`str`): + Required. The name of the Index resource to be deleted. + Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = index_service.DeleteIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_index, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'IndexServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py new file mode 100644 index 0000000000..133cf63a94 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -0,0 +1,839 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.index_service import pagers +from google.cloud.aiplatform_v1beta1.types import deployed_index_ref +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index as gca_index +from google.cloud.aiplatform_v1beta1.types import index_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import IndexServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import IndexServiceGrpcTransport +from .transports.grpc_asyncio import IndexServiceGrpcAsyncIOTransport + + +class IndexServiceClientMeta(type): + """Metaclass for the IndexService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[IndexServiceTransport]] + _transport_registry['grpc'] = IndexServiceGrpcTransport + _transport_registry['grpc_asyncio'] = IndexServiceGrpcAsyncIOTransport + + def get_transport_class(cls, + label: str = None, + ) -> Type[IndexServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class IndexServiceClient(metaclass=IndexServiceClientMeta): + """A service for creating and managing AI Platform's Index + resources. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IndexServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> IndexServiceTransport: + """Return the transport used by the client instance. + + Returns: + IndexServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def index_path(project: str,location: str,index: str,) -> str: + """Return a fully-qualified index string.""" + return "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + + @staticmethod + def parse_index_path(path: str) -> Dict[str,str]: + """Parse a index path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexes/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def index_endpoint_path(project: str,location: str,index_endpoint: str,) -> str: + """Return a fully-qualified index_endpoint string.""" + return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + + @staticmethod + def parse_index_endpoint_path(path: str) -> Dict[str,str]: + """Parse a index_endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, IndexServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the index service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, IndexServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, IndexServiceTransport): + # transport is a IndexServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_index(self, + request: index_service.CreateIndexRequest = None, + *, + parent: str = None, + index: gca_index.Index = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates an Index. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateIndexRequest): + The request object. Request message for + ``IndexService.CreateIndex``. + parent (str): + Required. The resource name of the Location to create + the Index in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + index (google.cloud.aiplatform_v1beta1.types.Index): + Required. The Index to create. + This corresponds to the ``index`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Index` A representation of a collection of database items organized in a way that + allows for approximate nearest neighbor (a.k.a ANN) + algorithms search. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, index]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_service.CreateIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_service.CreateIndexRequest): + request = index_service.CreateIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if index is not None: + request.index = index + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_index] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_index.Index, + metadata_type=index_service.CreateIndexOperationMetadata, + ) + + # Done; return the response. + return response + + def get_index(self, + request: index_service.GetIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index.Index: + r"""Gets an Index. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetIndexRequest): + The request object. Request message for + ``IndexService.GetIndex`` + name (str): + Required. The name of the Index resource. Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Index: + A representation of a collection of + database items organized in a way that + allows for approximate nearest neighbor + (a.k.a ANN) algorithms search. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_service.GetIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_service.GetIndexRequest): + request = index_service.GetIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_index] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_indexes(self, + request: index_service.ListIndexesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexesPager: + r"""Lists Indexes in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListIndexesRequest): + The request object. Request message for + ``IndexService.ListIndexes``. + parent (str): + Required. The resource name of the Location from which + to list the Indexes. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.index_service.pagers.ListIndexesPager: + Response message for + ``IndexService.ListIndexes``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_service.ListIndexesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_service.ListIndexesRequest): + request = index_service.ListIndexesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_indexes] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListIndexesPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_index(self, + request: index_service.UpdateIndexRequest = None, + *, + index: gca_index.Index = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Updates an Index. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateIndexRequest): + The request object. Request message for + [IndexService.UpdateModel][]. + index (google.cloud.aiplatform_v1beta1.types.Index): + Required. The Index which updates the + resource on the server. + + This corresponds to the ``index`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + The update mask applies to the resource. For the + ``FieldMask`` definition, see + `FieldMask `__. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Index` A representation of a collection of database items organized in a way that + allows for approximate nearest neighbor (a.k.a ANN) + algorithms search. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([index, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_service.UpdateIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_service.UpdateIndexRequest): + request = index_service.UpdateIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if index is not None: + request.index = index + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_index] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('index.name', request.index.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_index.Index, + metadata_type=index_service.UpdateIndexOperationMetadata, + ) + + # Done; return the response. + return response + + def delete_index(self, + request: index_service.DeleteIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes an Index. An Index can only be deleted when all its + ``DeployedIndexes`` + had been undeployed. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteIndexRequest): + The request object. Request message for + ``IndexService.DeleteIndex``. + name (str): + Required. The name of the Index resource to be deleted. + Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a index_service.DeleteIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, index_service.DeleteIndexRequest): + request = index_service.DeleteIndexRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_index] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'IndexServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py new file mode 100644 index 0000000000..dea7e37830 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional + +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index_service + + +class ListIndexesPager: + """A pager for iterating through ``list_indexes`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListIndexesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``indexes`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListIndexes`` requests and continue to iterate + through the ``indexes`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListIndexesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., index_service.ListIndexesResponse], + request: index_service.ListIndexesRequest, + response: index_service.ListIndexesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListIndexesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListIndexesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = index_service.ListIndexesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[index_service.ListIndexesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[index.Index]: + for page in self.pages: + yield from page.indexes + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListIndexesAsyncPager: + """A pager for iterating through ``list_indexes`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListIndexesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``indexes`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListIndexes`` requests and continue to iterate + through the ``indexes`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListIndexesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[index_service.ListIndexesResponse]], + request: index_service.ListIndexesRequest, + response: index_service.ListIndexesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListIndexesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListIndexesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = index_service.ListIndexesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[index_service.ListIndexesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[index.Index]: + async def async_generator(): + async for page in self.pages: + for response in page.indexes: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py new file mode 100644 index 0000000000..7bb2e2abad --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import IndexServiceTransport +from .grpc import IndexServiceGrpcTransport +from .grpc_asyncio import IndexServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[IndexServiceTransport]] +_transport_registry['grpc'] = IndexServiceGrpcTransport +_transport_registry['grpc_asyncio'] = IndexServiceGrpcAsyncIOTransport + +__all__ = ( + 'IndexServiceTransport', + 'IndexServiceGrpcTransport', + 'IndexServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py new file mode 100644 index 0000000000..fd218d13dd --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index_service +from google.longrunning import operations_pb2 as operations # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class IndexServiceTransport(abc.ABC): + """Abstract transport class for IndexService.""" + + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) + + def __init__( + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=self._scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_index: gapic_v1.method.wrap_method( + self.create_index, + default_timeout=None, + client_info=client_info, + ), + self.get_index: gapic_v1.method.wrap_method( + self.get_index, + default_timeout=None, + client_info=client_info, + ), + self.list_indexes: gapic_v1.method.wrap_method( + self.list_indexes, + default_timeout=None, + client_info=client_info, + ), + self.update_index: gapic_v1.method.wrap_method( + self.update_index, + default_timeout=None, + client_info=client_info, + ), + self.delete_index: gapic_v1.method.wrap_method( + self.delete_index, + default_timeout=None, + client_info=client_info, + ), + + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_index(self) -> typing.Callable[ + [index_service.CreateIndexRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def get_index(self) -> typing.Callable[ + [index_service.GetIndexRequest], + typing.Union[ + index.Index, + typing.Awaitable[index.Index] + ]]: + raise NotImplementedError() + + @property + def list_indexes(self) -> typing.Callable[ + [index_service.ListIndexesRequest], + typing.Union[ + index_service.ListIndexesResponse, + typing.Awaitable[index_service.ListIndexesResponse] + ]]: + raise NotImplementedError() + + @property + def update_index(self) -> typing.Callable[ + [index_service.UpdateIndexRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + @property + def delete_index(self) -> typing.Callable[ + [index_service.DeleteIndexRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'IndexServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py new file mode 100644 index 0000000000..783ab5733f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py @@ -0,0 +1,375 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import IndexServiceTransport, DEFAULT_CLIENT_INFO + + +class IndexServiceGrpcTransport(IndexServiceTransport): + """gRPC backend transport for IndexService. + + A service for creating and managing AI Platform's Index + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_index(self) -> Callable[ + [index_service.CreateIndexRequest], + operations.Operation]: + r"""Return a callable for the create index method over gRPC. + + Creates an Index. + + Returns: + Callable[[~.CreateIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_index' not in self._stubs: + self._stubs['create_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/CreateIndex', + request_serializer=index_service.CreateIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_index'] + + @property + def get_index(self) -> Callable[ + [index_service.GetIndexRequest], + index.Index]: + r"""Return a callable for the get index method over gRPC. + + Gets an Index. + + Returns: + Callable[[~.GetIndexRequest], + ~.Index]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_index' not in self._stubs: + self._stubs['get_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/GetIndex', + request_serializer=index_service.GetIndexRequest.serialize, + response_deserializer=index.Index.deserialize, + ) + return self._stubs['get_index'] + + @property + def list_indexes(self) -> Callable[ + [index_service.ListIndexesRequest], + index_service.ListIndexesResponse]: + r"""Return a callable for the list indexes method over gRPC. + + Lists Indexes in a Location. + + Returns: + Callable[[~.ListIndexesRequest], + ~.ListIndexesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_indexes' not in self._stubs: + self._stubs['list_indexes'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/ListIndexes', + request_serializer=index_service.ListIndexesRequest.serialize, + response_deserializer=index_service.ListIndexesResponse.deserialize, + ) + return self._stubs['list_indexes'] + + @property + def update_index(self) -> Callable[ + [index_service.UpdateIndexRequest], + operations.Operation]: + r"""Return a callable for the update index method over gRPC. + + Updates an Index. + + Returns: + Callable[[~.UpdateIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_index' not in self._stubs: + self._stubs['update_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/UpdateIndex', + request_serializer=index_service.UpdateIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_index'] + + @property + def delete_index(self) -> Callable[ + [index_service.DeleteIndexRequest], + operations.Operation]: + r"""Return a callable for the delete index method over gRPC. + + Deletes an Index. An Index can only be deleted when all its + ``DeployedIndexes`` + had been undeployed. + + Returns: + Callable[[~.DeleteIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_index' not in self._stubs: + self._stubs['delete_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/DeleteIndex', + request_serializer=index_service.DeleteIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_index'] + + +__all__ = ( + 'IndexServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..e0287ff613 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import IndexServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import IndexServiceGrpcTransport + + +class IndexServiceGrpcAsyncIOTransport(IndexServiceTransport): + """gRPC AsyncIO backend transport for IndexService. + + A service for creating and managing AI Platform's Index + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_index(self) -> Callable[ + [index_service.CreateIndexRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create index method over gRPC. + + Creates an Index. + + Returns: + Callable[[~.CreateIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_index' not in self._stubs: + self._stubs['create_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/CreateIndex', + request_serializer=index_service.CreateIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_index'] + + @property + def get_index(self) -> Callable[ + [index_service.GetIndexRequest], + Awaitable[index.Index]]: + r"""Return a callable for the get index method over gRPC. + + Gets an Index. + + Returns: + Callable[[~.GetIndexRequest], + Awaitable[~.Index]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_index' not in self._stubs: + self._stubs['get_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/GetIndex', + request_serializer=index_service.GetIndexRequest.serialize, + response_deserializer=index.Index.deserialize, + ) + return self._stubs['get_index'] + + @property + def list_indexes(self) -> Callable[ + [index_service.ListIndexesRequest], + Awaitable[index_service.ListIndexesResponse]]: + r"""Return a callable for the list indexes method over gRPC. + + Lists Indexes in a Location. + + Returns: + Callable[[~.ListIndexesRequest], + Awaitable[~.ListIndexesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_indexes' not in self._stubs: + self._stubs['list_indexes'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/ListIndexes', + request_serializer=index_service.ListIndexesRequest.serialize, + response_deserializer=index_service.ListIndexesResponse.deserialize, + ) + return self._stubs['list_indexes'] + + @property + def update_index(self) -> Callable[ + [index_service.UpdateIndexRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the update index method over gRPC. + + Updates an Index. + + Returns: + Callable[[~.UpdateIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_index' not in self._stubs: + self._stubs['update_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/UpdateIndex', + request_serializer=index_service.UpdateIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_index'] + + @property + def delete_index(self) -> Callable[ + [index_service.DeleteIndexRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete index method over gRPC. + + Deletes an Index. An Index can only be deleted when all its + ``DeployedIndexes`` + had been undeployed. + + Returns: + Callable[[~.DeleteIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_index' not in self._stubs: + self._stubs['delete_index'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.IndexService/DeleteIndex', + request_serializer=index_service.DeleteIndexRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_index'] + + +__all__ = ( + 'IndexServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 17a2d7a221..f8d4a0e95d 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -75,6 +75,9 @@ ListDatasetsResponse, UpdateDatasetRequest, ) +from .deployed_index_ref import ( + DeployedIndexRef, +) from .deployed_model_ref import ( DeployedModelRef, ) @@ -193,6 +196,41 @@ from .hyperparameter_tuning_job import ( HyperparameterTuningJob, ) +from .index import ( + Index, +) +from .index_endpoint import ( + DeployedIndex, + DeployedIndexAuthConfig, + IndexEndpoint, + IndexPrivateEndpoints, +) +from .index_endpoint_service import ( + CreateIndexEndpointOperationMetadata, + CreateIndexEndpointRequest, + DeleteIndexEndpointRequest, + DeployIndexOperationMetadata, + DeployIndexRequest, + DeployIndexResponse, + GetIndexEndpointRequest, + ListIndexEndpointsRequest, + ListIndexEndpointsResponse, + UndeployIndexOperationMetadata, + UndeployIndexRequest, + UndeployIndexResponse, + UpdateIndexEndpointRequest, +) +from .index_service import ( + CreateIndexOperationMetadata, + CreateIndexRequest, + DeleteIndexRequest, + GetIndexRequest, + ListIndexesRequest, + ListIndexesResponse, + NearestNeighborSearchOperationMetadata, + UpdateIndexOperationMetadata, + UpdateIndexRequest, +) from .io import ( AvroSource, BigQueryDestination, @@ -476,6 +514,7 @@ 'ListDatasetsRequest', 'ListDatasetsResponse', 'UpdateDatasetRequest', + 'DeployedIndexRef', 'DeployedModelRef', 'EncryptionSpec', 'DeployedModel', @@ -558,6 +597,33 @@ 'UpdateFeaturestoreOperationMetadata', 'UpdateFeaturestoreRequest', 'HyperparameterTuningJob', + 'Index', + 'DeployedIndex', + 'DeployedIndexAuthConfig', + 'IndexEndpoint', + 'IndexPrivateEndpoints', + 'CreateIndexEndpointOperationMetadata', + 'CreateIndexEndpointRequest', + 'DeleteIndexEndpointRequest', + 'DeployIndexOperationMetadata', + 'DeployIndexRequest', + 'DeployIndexResponse', + 'GetIndexEndpointRequest', + 'ListIndexEndpointsRequest', + 'ListIndexEndpointsResponse', + 'UndeployIndexOperationMetadata', + 'UndeployIndexRequest', + 'UndeployIndexResponse', + 'UpdateIndexEndpointRequest', + 'CreateIndexOperationMetadata', + 'CreateIndexRequest', + 'DeleteIndexRequest', + 'GetIndexRequest', + 'ListIndexesRequest', + 'ListIndexesResponse', + 'NearestNeighborSearchOperationMetadata', + 'UpdateIndexOperationMetadata', + 'UpdateIndexRequest', 'AvroSource', 'BigQueryDestination', 'BigQuerySource', diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py new file mode 100644 index 0000000000..eee6fd93f9 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'DeployedIndexRef', + }, +) + + +class DeployedIndexRef(proto.Message): + r"""Points to a DeployedIndex. + + Attributes: + index_endpoint (str): + Immutable. A resource name of the + IndexEndpoint. + deployed_index_id (str): + Immutable. The ID of the DeployedIndex in the + above IndexEndpoint. + """ + + index_endpoint = proto.Field(proto.STRING, number=1) + + deployed_index_id = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/index.py b/google/cloud/aiplatform_v1beta1/types/index.py new file mode 100644 index 0000000000..abf5ebf8ac --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/index.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import deployed_index_ref +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Index', + }, +) + + +class Index(proto.Message): + r"""A representation of a collection of database items organized + in a way that allows for approximate nearest neighbor (a.k.a + ANN) algorithms search. + + Attributes: + name (str): + Output only. The resource name of the Index. + display_name (str): + Required. The display name of the Index. + The name can be up to 128 characters long and + can be consist of any UTF-8 characters. + description (str): + The description of the Index. + metadata_schema_uri (str): + Immutable. Points to a YAML file stored on Google Cloud + Storage describing additional information about the Index, + that is specific to it. Unset if the Index does not have any + additional information. The schema is defined as an OpenAPI + 3.0.2 `Schema + Object `__. + Note: The URI given on output will be immutable and probably + different, including the URI scheme, than the one given on + input. The output URI will point to a location where the + user only has a read access. + metadata (google.protobuf.struct_pb2.Value): + An additional information about the Index; the schema of the + metadata can be found in + ``metadata_schema``. + deployed_indexes (Sequence[google.cloud.aiplatform_v1beta1.types.DeployedIndexRef]): + Output only. The pointers to DeployedIndexes + created from this Index. An Index can be only + deleted if all its DeployedIndexes had been + undeployed first. + etag (str): + Used to perform consistent read-modify-write + updates. If not set, a blind "overwrite" update + happens. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Index.LabelsEntry]): + The labels with user-defined metadata to + organize your Indexes. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Index was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Index was most recently + updated. This also includes any update to the contents of + the Index. Note that Operations working on this Index may + have their + [Operations.metadata.generic_metadata.update_time] + [google.cloud.aiplatform.v1beta1.GenericOperationMetadata.update_time] + a little after the value of this timestamp, yet that does + not mean their results are not already reflected in the + Index. Result of any successfully completed Operation on the + Index is reflected in it. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + metadata_schema_uri = proto.Field(proto.STRING, number=4) + + metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) + + deployed_indexes = proto.RepeatedField(proto.MESSAGE, number=7, + message=deployed_index_ref.DeployedIndexRef, + ) + + etag = proto.Field(proto.STRING, number=8) + + labels = proto.MapField(proto.STRING, proto.STRING, number=9) + + create_time = proto.Field(proto.MESSAGE, number=10, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py new file mode 100644 index 0000000000..f1bbd1b62a --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'IndexEndpoint', + 'DeployedIndex', + 'DeployedIndexAuthConfig', + 'IndexPrivateEndpoints', + }, +) + + +class IndexEndpoint(proto.Message): + r"""Indexes are deployed into it. An IndexEndpoint can have + multiple DeployedIndexes. + + Attributes: + name (str): + Output only. The resource name of the + IndexEndpoint. + display_name (str): + Required. The display name of the + IndexEndpoint. The name can be up to 128 + characters long and can consist of any UTF-8 + characters. + description (str): + The description of the IndexEndpoint. + deployed_indexes (Sequence[google.cloud.aiplatform_v1beta1.types.DeployedIndex]): + Output only. The indexes deployed in this + endpoint. + etag (str): + Used to perform consistent read-modify-write + updates. If not set, a blind "overwrite" update + happens. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.IndexEndpoint.LabelsEntry]): + The labels with user-defined metadata to + organize your IndexEndpoints. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + IndexEndpoint was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + IndexEndpoint was last updated. This timestamp + is not updated when the endpoint's + DeployedIndexes are updated, e.g. due to updates + of the original Indexes they are the deployments + of. + network (str): + Required. Immutable. The full name of the Google Compute + Engine + `network `__ + to which the IndexEndpoint should be peered. + + Private services access must already be configured for the + network. If left unspecified, the Endpoint is not peered + with any network. + + `Format `__: + projects/{project}/global/networks/{network}. Where + {project} is a project number, as in '12345', and {network} + is network name. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + deployed_indexes = proto.RepeatedField(proto.MESSAGE, number=4, + message='DeployedIndex', + ) + + etag = proto.Field(proto.STRING, number=5) + + labels = proto.MapField(proto.STRING, proto.STRING, number=6) + + create_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) + + network = proto.Field(proto.STRING, number=9) + + +class DeployedIndex(proto.Message): + r"""A deployment of an Index. IndexEndpoints contain one or more + DeployedIndexes. + + Attributes: + id (str): + Required. The user specified ID of the + DeployedIndex. The ID can be up to 128 + characters long and must start with a letter and + only contain letters, numbers, and underscores. + The ID must be unique within the project it is + created in. + index (str): + Required. The name of the Index this is the + deployment of. We may refer to this Index as the + DeployedIndex's "original" Index. + display_name (str): + The display name of the DeployedIndex. If not provided upon + creation, the Index's display_name is used. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when the DeployedIndex + was created. + private_endpoints (google.cloud.aiplatform_v1beta1.types.IndexPrivateEndpoints): + Output only. Provides paths for users to send requests + directly to the deployed index services running on Cloud via + private services access. This field is populated if + ``network`` + is configured. + index_sync_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The DeployedIndex may depend on various data on + its original Index. Additionally when certain changes to the + original Index are being done (e.g. when what the Index + contains is being changed) the DeployedIndex may be + asynchronously updated in the background to reflect this + changes. If this timestamp's value is at least the + ``Index.update_time`` + of the original Index, it means that this DeployedIndex and + the original Index are in sync. If this timestamp is older, + then to see which updates this DeployedIndex already + contains (and which not), one must + ``list`` ``Operations`` + ``working`` on the original Index. Only the + successfully completed Operations with + [Operations.metadata.generic_metadata.update_time] + [google.cloud.aiplatform.v1beta1.GenericOperationMetadata.update_time] + equal or before this sync time are contained in this + DeployedIndex. + automatic_resources (google.cloud.aiplatform_v1beta1.types.AutomaticResources): + Optional. A description of resources that the DeployedIndex + uses, which to large degree are decided by AI Platform, and + optionally allows only a modest additional configuration. If + min_replica_count is not set, the default value is 1. If + max_replica_count is not set, the default value is + min_replica_count. The max allowed replica count is 1000. + + The user is billed for the resources (at least their minimal + amount) even if the DeployedIndex receives no traffic. + enable_access_logging (bool): + Optional. If true, private endpoint's access + logs are sent to StackDriver Logging. + These logs are like standard server access logs, + containing information like timestamp and + latency for each MatchRequest. + Note that Stackdriver logs may incur a cost, + especially if the deployed index receives a high + queries per second rate (QPS). Estimate your + costs before enabling this option. + deployed_index_auth_config (google.cloud.aiplatform_v1beta1.types.DeployedIndexAuthConfig): + Optional. If set, the authentication is + enabled for the private endpoint. + """ + + id = proto.Field(proto.STRING, number=1) + + index = proto.Field(proto.STRING, number=2) + + display_name = proto.Field(proto.STRING, number=3) + + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + private_endpoints = proto.Field(proto.MESSAGE, number=5, + message='IndexPrivateEndpoints', + ) + + index_sync_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) + + automatic_resources = proto.Field(proto.MESSAGE, number=7, + message=machine_resources.AutomaticResources, + ) + + enable_access_logging = proto.Field(proto.BOOL, number=8) + + deployed_index_auth_config = proto.Field(proto.MESSAGE, number=9, + message='DeployedIndexAuthConfig', + ) + + +class DeployedIndexAuthConfig(proto.Message): + r"""Used to set up the auth on the DeployedIndex's private + endpoint. + + Attributes: + auth_provider (google.cloud.aiplatform_v1beta1.types.DeployedIndexAuthConfig.AuthProvider): + Defines the authentication provider that the + DeployedIndex uses. + """ + class AuthProvider(proto.Message): + r"""Configuration for an authentication provider, including support for + `JSON Web Token + (JWT) `__. + + Attributes: + audiences (Sequence[str]): + The list of JWT + `audiences `__. + that are allowed to access. A JWT containing any of these + audiences will be accepted. + """ + + audiences = proto.RepeatedField(proto.STRING, number=1) + + auth_provider = proto.Field(proto.MESSAGE, number=1, + message=AuthProvider, + ) + + +class IndexPrivateEndpoints(proto.Message): + r"""IndexPrivateEndpoints proto is used to provide paths for + users to send requests via private services access. + + Attributes: + match_grpc_address (str): + Output only. The ip address used to send + match gRPC requests. + """ + + match_grpc_address = proto.Field(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py new file mode 100644 index 0000000000..cf5abb0c5a --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import operation +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'CreateIndexEndpointRequest', + 'CreateIndexEndpointOperationMetadata', + 'GetIndexEndpointRequest', + 'ListIndexEndpointsRequest', + 'ListIndexEndpointsResponse', + 'UpdateIndexEndpointRequest', + 'DeleteIndexEndpointRequest', + 'DeployIndexRequest', + 'DeployIndexResponse', + 'DeployIndexOperationMetadata', + 'UndeployIndexRequest', + 'UndeployIndexResponse', + 'UndeployIndexOperationMetadata', + }, +) + + +class CreateIndexEndpointRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.CreateIndexEndpoint``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + IndexEndpoint in. Format: + ``projects/{project}/locations/{location}`` + index_endpoint (google.cloud.aiplatform_v1beta1.types.IndexEndpoint): + Required. The IndexEndpoint to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + index_endpoint = proto.Field(proto.MESSAGE, number=2, + message=gca_index_endpoint.IndexEndpoint, + ) + + +class CreateIndexEndpointOperationMetadata(proto.Message): + r"""Runtime operation information for + ``IndexEndpointService.CreateIndexEndpoint``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class GetIndexEndpointRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.GetIndexEndpoint`` + + Attributes: + name (str): + Required. The name of the IndexEndpoint resource. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListIndexEndpointsRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.ListIndexEndpoints``. + + Attributes: + parent (str): + Required. The resource name of the Location from which to + list the IndexEndpoints. Format: + ``projects/{project}/locations/{location}`` + filter (str): + Optional. An expression for filtering the results of the + request. For field names both snake_case and camelCase are + supported. + + - ``index_endpoint`` supports = and !=. ``index_endpoint`` + represents the IndexEndpoint ID, ie. the last segment of + the IndexEndpoint's + ``resourcename``. + - ``display_name`` supports =, != and regex() (uses + `re2 `__ + syntax) + - ``labels`` supports general map functions that is: + ``labels.key=value`` - key:value equality + ``labels.key:* or labels:key - key existence A key including a space must be quoted.``\ labels."a + key"`. + + Some examples: + + - ``index_endpoint="1"`` + - ``display_name="myDisplayName"`` + - \`regex(display_name, "^A") -> The display name starts + with an A. + - ``labels.myKey="myValue"`` + page_size (int): + Optional. The standard list page size. + page_token (str): + Optional. The standard list page token. Typically obtained + via + ``ListIndexEndpointsResponse.next_page_token`` + of the previous + ``IndexEndpointService.ListIndexEndpoints`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Optional. Mask specifying which fields to + read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + + +class ListIndexEndpointsResponse(proto.Message): + r"""Response message for + ``IndexEndpointService.ListIndexEndpoints``. + + Attributes: + index_endpoints (Sequence[google.cloud.aiplatform_v1beta1.types.IndexEndpoint]): + List of IndexEndpoints in the requested page. + next_page_token (str): + A token to retrieve next page of results. Pass to + ``ListIndexEndpointsRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + index_endpoints = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_index_endpoint.IndexEndpoint, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateIndexEndpointRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.UpdateIndexEndpoint``. + + Attributes: + index_endpoint (google.cloud.aiplatform_v1beta1.types.IndexEndpoint): + Required. The IndexEndpoint which replaces + the resource on the server. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + `FieldMask `__. + """ + + index_endpoint = proto.Field(proto.MESSAGE, number=1, + message=gca_index_endpoint.IndexEndpoint, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + +class DeleteIndexEndpointRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.DeleteIndexEndpoint``. + + Attributes: + name (str): + Required. The name of the IndexEndpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class DeployIndexRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.DeployIndex``. + + Attributes: + index_endpoint (str): + Required. The name of the IndexEndpoint resource into which + to deploy an Index. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + deployed_index (google.cloud.aiplatform_v1beta1.types.DeployedIndex): + Required. The DeployedIndex to be created + within the IndexEndpoint. + """ + + index_endpoint = proto.Field(proto.STRING, number=1) + + deployed_index = proto.Field(proto.MESSAGE, number=2, + message=gca_index_endpoint.DeployedIndex, + ) + + +class DeployIndexResponse(proto.Message): + r"""Response message for + ``IndexEndpointService.DeployIndex``. + + Attributes: + deployed_index (google.cloud.aiplatform_v1beta1.types.DeployedIndex): + The DeployedIndex that had been deployed in + the IndexEndpoint. + """ + + deployed_index = proto.Field(proto.MESSAGE, number=1, + message=gca_index_endpoint.DeployedIndex, + ) + + +class DeployIndexOperationMetadata(proto.Message): + r"""Runtime operation information for + ``IndexEndpointService.DeployIndex``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +class UndeployIndexRequest(proto.Message): + r"""Request message for + ``IndexEndpointService.UndeployIndex``. + + Attributes: + index_endpoint (str): + Required. The name of the IndexEndpoint resource from which + to undeploy an Index. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + deployed_index_id (str): + Required. The ID of the DeployedIndex to be + undeployed from the IndexEndpoint. + """ + + index_endpoint = proto.Field(proto.STRING, number=1) + + deployed_index_id = proto.Field(proto.STRING, number=2) + + +class UndeployIndexResponse(proto.Message): + r"""Response message for + ``IndexEndpointService.UndeployIndex``. + """ + + +class UndeployIndexOperationMetadata(proto.Message): + r"""Runtime operation information for + ``IndexEndpointService.UndeployIndex``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/index_service.py b/google/cloud/aiplatform_v1beta1/types/index_service.py new file mode 100644 index 0000000000..56cb293e93 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/index_service.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import index as gca_index +from google.cloud.aiplatform_v1beta1.types import operation +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'CreateIndexRequest', + 'CreateIndexOperationMetadata', + 'GetIndexRequest', + 'ListIndexesRequest', + 'ListIndexesResponse', + 'UpdateIndexRequest', + 'UpdateIndexOperationMetadata', + 'DeleteIndexRequest', + 'NearestNeighborSearchOperationMetadata', + }, +) + + +class CreateIndexRequest(proto.Message): + r"""Request message for + ``IndexService.CreateIndex``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + Index in. Format: + ``projects/{project}/locations/{location}`` + index (google.cloud.aiplatform_v1beta1.types.Index): + Required. The Index to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + index = proto.Field(proto.MESSAGE, number=2, + message=gca_index.Index, + ) + + +class CreateIndexOperationMetadata(proto.Message): + r"""Runtime operation information for + ``IndexService.CreateIndex``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + nearest_neighbor_search_operation_metadata (google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata): + The operation metadata with regard to ScaNN + Index operation. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + nearest_neighbor_search_operation_metadata = proto.Field(proto.MESSAGE, number=2, + message='NearestNeighborSearchOperationMetadata', + ) + + +class GetIndexRequest(proto.Message): + r"""Request message for + ``IndexService.GetIndex`` + + Attributes: + name (str): + Required. The name of the Index resource. Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListIndexesRequest(proto.Message): + r"""Request message for + ``IndexService.ListIndexes``. + + Attributes: + parent (str): + Required. The resource name of the Location from which to + list the Indexes. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListIndexesResponse.next_page_token`` + of the previous + ``IndexService.ListIndexes`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + + +class ListIndexesResponse(proto.Message): + r"""Response message for + ``IndexService.ListIndexes``. + + Attributes: + indexes (Sequence[google.cloud.aiplatform_v1beta1.types.Index]): + List of indexes in the requested page. + next_page_token (str): + A token to retrieve next page of results. Pass to + ``ListIndexesRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + indexes = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_index.Index, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateIndexRequest(proto.Message): + r"""Request message for [IndexService.UpdateModel][]. + + Attributes: + index (google.cloud.aiplatform_v1beta1.types.Index): + Required. The Index which updates the + resource on the server. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + The update mask applies to the resource. For the + ``FieldMask`` definition, see + `FieldMask `__. + """ + + index = proto.Field(proto.MESSAGE, number=1, + message=gca_index.Index, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) + + +class UpdateIndexOperationMetadata(proto.Message): + r"""Runtime operation information for + ``IndexService.UpdateIndex``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. + nearest_neighbor_search_operation_metadata (google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata): + The operation metadata with regard to ScaNN + Index operation. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + nearest_neighbor_search_operation_metadata = proto.Field(proto.MESSAGE, number=2, + message='NearestNeighborSearchOperationMetadata', + ) + + +class DeleteIndexRequest(proto.Message): + r"""Request message for + ``IndexService.DeleteIndex``. + + Attributes: + name (str): + Required. The name of the Index resource to be deleted. + Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class NearestNeighborSearchOperationMetadata(proto.Message): + r"""Runtime operation metadata with regard to ScaNN Index. + + Attributes: + content_validation_stats (Sequence[google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata.ContentValidationStats]): + The validation stats of the content (per file) to be + inserted or updated on the ScaNN Index resource. Populated + if contentsDeltaUri is provided as part of + ``Index.metadata``. + Please note that, currently for those files that are broken + or has unsupported file format, we will not have the stats + for those files. + """ + class RecordError(proto.Message): + r""" + + Attributes: + error_type (google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata.RecordError.RecordErrorType): + The error type of this record. + error_message (str): + A human-readable message that is shown to the user to help + them fix the error. Note that this message may change from + time to time, your code should check against error_type as + the source of truth. + source_gcs_uri (str): + GCS uri pointing to the original file in + user's bucket. + embedding_id (str): + Empty if the embedding id is failed to parse. + raw_record (str): + The original content of this record. + """ + class RecordErrorType(proto.Enum): + r"""""" + ERROR_TYPE_UNSPECIFIED = 0 + EMPTY_LINE = 1 + INVALID_JSON_SYNTAX = 2 + INVALID_CSV_SYNTAX = 3 + INVALID_AVRO_SYNTAX = 4 + INVALID_EMBEDDING_ID = 5 + EMBEDDING_SIZE_MISMATCH = 6 + NAMESPACE_MISSING = 7 + + error_type = proto.Field(proto.ENUM, number=1, + enum='NearestNeighborSearchOperationMetadata.RecordError.RecordErrorType', + ) + + error_message = proto.Field(proto.STRING, number=2) + + source_gcs_uri = proto.Field(proto.STRING, number=3) + + embedding_id = proto.Field(proto.STRING, number=4) + + raw_record = proto.Field(proto.STRING, number=5) + + class ContentValidationStats(proto.Message): + r""" + + Attributes: + source_gcs_uri (str): + GCS uri pointing to the original file in + user's bucket. + valid_record_count (int): + Number of records in this file that were + successfully processed. + invalid_record_count (int): + Number of records in this file we skipped due + to validate errors. + partial_errors (Sequence[google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata.RecordError]): + The detail information of the partial + failures encountered for those invalid records + that couldn't be parsed. Up to 50 partial errors + will be reported. + """ + + source_gcs_uri = proto.Field(proto.STRING, number=1) + + valid_record_count = proto.Field(proto.INT64, number=2) + + invalid_record_count = proto.Field(proto.INT64, number=3) + + partial_errors = proto.RepeatedField(proto.MESSAGE, number=4, + message='NearestNeighborSearchOperationMetadata.RecordError', + ) + + content_validation_stats = proto.RepeatedField(proto.MESSAGE, number=1, + message=ContentValidationStats, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/samples/snippets/create_custom_job_sample_test.py b/samples/snippets/create_custom_job_sample_test.py index 212dd41e3c..0a29132cdc 100644 --- a/samples/snippets/create_custom_job_sample_test.py +++ b/samples/snippets/create_custom_job_sample_test.py @@ -22,7 +22,7 @@ import helpers PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") -CONTAINER_IMAGE_URI = "gcr.io/ucaip-test/ucaip-training-test:latest" +CONTAINER_IMAGE_URI = "gcr.io/ucaip-sample-tests/ucaip-training-test:latest" @pytest.fixture(scope="function", autouse=True) diff --git a/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py b/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py index f430fc38ed..7bb5ec5ac3 100644 --- a/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py +++ b/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py @@ -26,7 +26,7 @@ ) EXECUTOR_IMAGE_URI = "us.gcr.io/cloud-aiplatform/training/tf-gpu.2-1:latest" -PACKAGE_URI = "gs://ucaip-test-us-central1/training/pythonpackages/trainer.tar.bz2" +PACKAGE_URI = "gs://cloud-samples-data-us-central1/ai-platform-unified/training/python-packages/trainer.tar.bz2" PYTHON_MODULE = "trainer.hptuning_trainer" diff --git a/samples/snippets/create_hyperparameter_tuning_job_sample_test.py b/samples/snippets/create_hyperparameter_tuning_job_sample_test.py index ad1f0ae4db..9a16bdcb9c 100644 --- a/samples/snippets/create_hyperparameter_tuning_job_sample_test.py +++ b/samples/snippets/create_hyperparameter_tuning_job_sample_test.py @@ -21,7 +21,7 @@ import helpers PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") -CONTAINER_IMAGE_URI = "gcr.io/ucaip-test/ucaip-training-test:latest" +CONTAINER_IMAGE_URI = "gcr.io/ucaip-sample-tests/ucaip-training-test:latest" @pytest.fixture(scope="function", autouse=True) diff --git a/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py b/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py index 82725f3847..2323163c9e 100644 --- a/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py +++ b/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py @@ -30,7 +30,7 @@ ANNOTATION_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml" TRAINING_CONTAINER_SPEC_IMAGE_URI = ( - "gcr.io/ucaip-test/custom-container-managed-dataset:latest" + "gcr.io/ucaip-sample-tests/custom-container-managed-dataset:latest" ) MODEL_CONTAINER_SPEC_IMAGE_URI = "gcr.io/cloud-aiplatform/prediction/tf-gpu.1-15:latest" diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py new file mode 100644 index 0000000000..c8209a3cae --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -0,0 +1,2863 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import IndexEndpointServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import IndexEndpointServiceClient +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import transports +from google.cloud.aiplatform_v1beta1.types import index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint as gca_index_endpoint +from google.cloud.aiplatform_v1beta1.types import index_endpoint_service +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert IndexEndpointServiceClient._get_default_mtls_endpoint(None) is None + assert IndexEndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert IndexEndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert IndexEndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert IndexEndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert IndexEndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ + IndexEndpointServiceClient, + IndexEndpointServiceAsyncClient, +]) +def test_index_endpoint_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +@pytest.mark.parametrize("client_class", [ + IndexEndpointServiceClient, + IndexEndpointServiceAsyncClient, +]) +def test_index_endpoint_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_index_endpoint_service_client_get_transport_class(): + transport = IndexEndpointServiceClient.get_transport_class() + available_transports = [ + transports.IndexEndpointServiceGrpcTransport, + ] + assert transport in available_transports + + transport = IndexEndpointServiceClient.get_transport_class("grpc") + assert transport == transports.IndexEndpointServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc"), + (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(IndexEndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceClient)) +@mock.patch.object(IndexEndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceAsyncClient)) +def test_index_endpoint_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(IndexEndpointServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(IndexEndpointServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc", "true"), + (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc", "false"), + (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(IndexEndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceClient)) +@mock.patch.object(IndexEndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_index_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc"), + (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_index_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc"), + (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_index_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_index_endpoint_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = IndexEndpointServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.CreateIndexEndpointRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.CreateIndexEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_index_endpoint_from_dict(): + test_create_index_endpoint(request_type=dict) + + +def test_create_index_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + client.create_index_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.CreateIndexEndpointRequest() + +@pytest.mark.asyncio +async def test_create_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.CreateIndexEndpointRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.CreateIndexEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_index_endpoint_async_from_dict(): + await test_create_index_endpoint_async(request_type=dict) + + +def test_create_index_endpoint_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.CreateIndexEndpointRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_index_endpoint_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.CreateIndexEndpointRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_index_endpoint_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_index_endpoint( + parent='parent_value', + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + + +def test_create_index_endpoint_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_index_endpoint( + index_endpoint_service.CreateIndexEndpointRequest(), + parent='parent_value', + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_index_endpoint_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_index_endpoint( + parent='parent_value', + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + + +@pytest.mark.asyncio +async def test_create_index_endpoint_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_index_endpoint( + index_endpoint_service.CreateIndexEndpointRequest(), + parent='parent_value', + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + ) + + +def test_get_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.GetIndexEndpointRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_endpoint.IndexEndpoint( + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + + network='network_value', + + ) + + response = client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.GetIndexEndpointRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, index_endpoint.IndexEndpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + assert response.network == 'network_value' + + +def test_get_index_endpoint_from_dict(): + test_get_index_endpoint(request_type=dict) + + +def test_get_index_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + client.get_index_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.GetIndexEndpointRequest() + +@pytest.mark.asyncio +async def test_get_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.GetIndexEndpointRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint.IndexEndpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + network='network_value', + )) + + response = await client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.GetIndexEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, index_endpoint.IndexEndpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + assert response.network == 'network_value' + + +@pytest.mark.asyncio +async def test_get_index_endpoint_async_from_dict(): + await test_get_index_endpoint_async(request_type=dict) + + +def test_get_index_endpoint_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.GetIndexEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + call.return_value = index_endpoint.IndexEndpoint() + + client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_index_endpoint_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.GetIndexEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint.IndexEndpoint()) + + await client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_index_endpoint_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_endpoint.IndexEndpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_index_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_index_endpoint_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_index_endpoint( + index_endpoint_service.GetIndexEndpointRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_index_endpoint_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_endpoint.IndexEndpoint() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint.IndexEndpoint()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_index_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_index_endpoint_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_index_endpoint( + index_endpoint_service.GetIndexEndpointRequest(), + name='name_value', + ) + + +def test_list_index_endpoints(transport: str = 'grpc', request_type=index_endpoint_service.ListIndexEndpointsRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_endpoint_service.ListIndexEndpointsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.ListIndexEndpointsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListIndexEndpointsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_index_endpoints_from_dict(): + test_list_index_endpoints(request_type=dict) + + +def test_list_index_endpoints_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + client.list_index_endpoints() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.ListIndexEndpointsRequest() + +@pytest.mark.asyncio +async def test_list_index_endpoints_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.ListIndexEndpointsRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint_service.ListIndexEndpointsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.ListIndexEndpointsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListIndexEndpointsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_index_endpoints_async_from_dict(): + await test_list_index_endpoints_async(request_type=dict) + + +def test_list_index_endpoints_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.ListIndexEndpointsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + call.return_value = index_endpoint_service.ListIndexEndpointsResponse() + + client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_index_endpoints_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.ListIndexEndpointsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint_service.ListIndexEndpointsResponse()) + + await client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_index_endpoints_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_endpoint_service.ListIndexEndpointsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_index_endpoints( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_index_endpoints_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_index_endpoints( + index_endpoint_service.ListIndexEndpointsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_index_endpoints_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_endpoint_service.ListIndexEndpointsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint_service.ListIndexEndpointsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_index_endpoints( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_index_endpoints_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_index_endpoints( + index_endpoint_service.ListIndexEndpointsRequest(), + parent='parent_value', + ) + + +def test_list_index_endpoints_pager(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + next_page_token='abc', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[], + next_page_token='def', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + ], + next_page_token='ghi', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_index_endpoints(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, index_endpoint.IndexEndpoint) + for i in results) + +def test_list_index_endpoints_pages(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + next_page_token='abc', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[], + next_page_token='def', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + ], + next_page_token='ghi', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + ), + RuntimeError, + ) + pages = list(client.list_index_endpoints(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_index_endpoints_async_pager(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + next_page_token='abc', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[], + next_page_token='def', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + ], + next_page_token='ghi', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_index_endpoints(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, index_endpoint.IndexEndpoint) + for i in responses) + +@pytest.mark.asyncio +async def test_list_index_endpoints_async_pages(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_index_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + next_page_token='abc', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[], + next_page_token='def', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + ], + next_page_token='ghi', + ), + index_endpoint_service.ListIndexEndpointsResponse( + index_endpoints=[ + index_endpoint.IndexEndpoint(), + index_endpoint.IndexEndpoint(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_index_endpoints(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.UpdateIndexEndpointRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_index_endpoint.IndexEndpoint( + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + + network='network_value', + + ) + + response = client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_index_endpoint.IndexEndpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + assert response.network == 'network_value' + + +def test_update_index_endpoint_from_dict(): + test_update_index_endpoint(request_type=dict) + + +def test_update_index_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + client.update_index_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() + +@pytest.mark.asyncio +async def test_update_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.UpdateIndexEndpointRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_index_endpoint.IndexEndpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + network='network_value', + )) + + response = await client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_index_endpoint.IndexEndpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + assert response.network == 'network_value' + + +@pytest.mark.asyncio +async def test_update_index_endpoint_async_from_dict(): + await test_update_index_endpoint_async(request_type=dict) + + +def test_update_index_endpoint_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.UpdateIndexEndpointRequest() + request.index_endpoint.name = 'index_endpoint.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + call.return_value = gca_index_endpoint.IndexEndpoint() + + client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index_endpoint.name=index_endpoint.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_index_endpoint_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.UpdateIndexEndpointRequest() + request.index_endpoint.name = 'index_endpoint.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_index_endpoint.IndexEndpoint()) + + await client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index_endpoint.name=index_endpoint.name/value', + ) in kw['metadata'] + + +def test_update_index_endpoint_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_index_endpoint.IndexEndpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_index_endpoint( + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_index_endpoint_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_index_endpoint( + index_endpoint_service.UpdateIndexEndpointRequest(), + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_index_endpoint_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_index_endpoint.IndexEndpoint() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_index_endpoint.IndexEndpoint()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_index_endpoint( + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_index_endpoint_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_index_endpoint( + index_endpoint_service.UpdateIndexEndpointRequest(), + index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.DeleteIndexEndpointRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.DeleteIndexEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_index_endpoint_from_dict(): + test_delete_index_endpoint(request_type=dict) + + +def test_delete_index_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + client.delete_index_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.DeleteIndexEndpointRequest() + +@pytest.mark.asyncio +async def test_delete_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.DeleteIndexEndpointRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.DeleteIndexEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_index_endpoint_async_from_dict(): + await test_delete_index_endpoint_async(request_type=dict) + + +def test_delete_index_endpoint_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.DeleteIndexEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_index_endpoint_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.DeleteIndexEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_index_endpoint_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_index_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_index_endpoint_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_index_endpoint( + index_endpoint_service.DeleteIndexEndpointRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_index_endpoint_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_index_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_index_endpoint_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_index_endpoint( + index_endpoint_service.DeleteIndexEndpointRequest(), + name='name_value', + ) + + +def test_deploy_index(transport: str = 'grpc', request_type=index_endpoint_service.DeployIndexRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.DeployIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_deploy_index_from_dict(): + test_deploy_index(request_type=dict) + + +def test_deploy_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + client.deploy_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.DeployIndexRequest() + +@pytest.mark.asyncio +async def test_deploy_index_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.DeployIndexRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.DeployIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_deploy_index_async_from_dict(): + await test_deploy_index_async(request_type=dict) + + +def test_deploy_index_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.DeployIndexRequest() + request.index_endpoint = 'index_endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index_endpoint=index_endpoint/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_deploy_index_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.DeployIndexRequest() + request.index_endpoint = 'index_endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index_endpoint=index_endpoint/value', + ) in kw['metadata'] + + +def test_deploy_index_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.deploy_index( + index_endpoint='index_endpoint_value', + deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].index_endpoint == 'index_endpoint_value' + + assert args[0].deployed_index == gca_index_endpoint.DeployedIndex(id='id_value') + + +def test_deploy_index_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.deploy_index( + index_endpoint_service.DeployIndexRequest(), + index_endpoint='index_endpoint_value', + deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + ) + + +@pytest.mark.asyncio +async def test_deploy_index_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.deploy_index( + index_endpoint='index_endpoint_value', + deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].index_endpoint == 'index_endpoint_value' + + assert args[0].deployed_index == gca_index_endpoint.DeployedIndex(id='id_value') + + +@pytest.mark.asyncio +async def test_deploy_index_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.deploy_index( + index_endpoint_service.DeployIndexRequest(), + index_endpoint='index_endpoint_value', + deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + ) + + +def test_undeploy_index(transport: str = 'grpc', request_type=index_endpoint_service.UndeployIndexRequest): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.UndeployIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_undeploy_index_from_dict(): + test_undeploy_index(request_type=dict) + + +def test_undeploy_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + client.undeploy_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.UndeployIndexRequest() + +@pytest.mark.asyncio +async def test_undeploy_index_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.UndeployIndexRequest): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_endpoint_service.UndeployIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_undeploy_index_async_from_dict(): + await test_undeploy_index_async(request_type=dict) + + +def test_undeploy_index_field_headers(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.UndeployIndexRequest() + request.index_endpoint = 'index_endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index_endpoint=index_endpoint/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_undeploy_index_field_headers_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_endpoint_service.UndeployIndexRequest() + request.index_endpoint = 'index_endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index_endpoint=index_endpoint/value', + ) in kw['metadata'] + + +def test_undeploy_index_flattened(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.undeploy_index( + index_endpoint='index_endpoint_value', + deployed_index_id='deployed_index_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].index_endpoint == 'index_endpoint_value' + + assert args[0].deployed_index_id == 'deployed_index_id_value' + + +def test_undeploy_index_flattened_error(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.undeploy_index( + index_endpoint_service.UndeployIndexRequest(), + index_endpoint='index_endpoint_value', + deployed_index_id='deployed_index_id_value', + ) + + +@pytest.mark.asyncio +async def test_undeploy_index_flattened_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.undeploy_index( + index_endpoint='index_endpoint_value', + deployed_index_id='deployed_index_id_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].index_endpoint == 'index_endpoint_value' + + assert args[0].deployed_index_id == 'deployed_index_id_value' + + +@pytest.mark.asyncio +async def test_undeploy_index_flattened_error_async(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.undeploy_index( + index_endpoint_service.UndeployIndexRequest(), + index_endpoint='index_endpoint_value', + deployed_index_id='deployed_index_id_value', + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = IndexEndpointServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = IndexEndpointServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.IndexEndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.IndexEndpointServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.IndexEndpointServiceGrpcTransport, + transports.IndexEndpointServiceGrpcAsyncIOTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.IndexEndpointServiceGrpcTransport, + ) + + +def test_index_endpoint_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.IndexEndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_index_endpoint_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.IndexEndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_index_endpoint', + 'get_index_endpoint', + 'list_index_endpoints', + 'update_index_endpoint', + 'delete_index_endpoint', + 'deploy_index', + 'undeploy_index', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_index_endpoint_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.IndexEndpointServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_index_endpoint_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.IndexEndpointServiceTransport() + adc.assert_called_once() + + +def test_index_endpoint_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + IndexEndpointServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_index_endpoint_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.IndexEndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize("transport_class", [transports.IndexEndpointServiceGrpcTransport, transports.IndexEndpointServiceGrpcAsyncIOTransport]) +def test_index_endpoint_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + + +def test_index_endpoint_service_host_no_port(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_index_endpoint_service_host_with_port(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_index_endpoint_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.IndexEndpointServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_index_endpoint_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.IndexEndpointServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.IndexEndpointServiceGrpcTransport, transports.IndexEndpointServiceGrpcAsyncIOTransport]) +def test_index_endpoint_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.IndexEndpointServiceGrpcTransport, transports.IndexEndpointServiceGrpcAsyncIOTransport]) +def test_index_endpoint_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_index_endpoint_service_grpc_lro_client(): + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_index_endpoint_service_grpc_lro_async_client(): + client = IndexEndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_index_path(): + project = "squid" + location = "clam" + index = "whelk" + + expected = "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + actual = IndexEndpointServiceClient.index_path(project, location, index) + assert expected == actual + + +def test_parse_index_path(): + expected = { + "project": "octopus", + "location": "oyster", + "index": "nudibranch", + + } + path = IndexEndpointServiceClient.index_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_index_path(path) + assert expected == actual + +def test_index_endpoint_path(): + project = "cuttlefish" + location = "mussel" + index_endpoint = "winkle" + + expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + actual = IndexEndpointServiceClient.index_endpoint_path(project, location, index_endpoint) + assert expected == actual + + +def test_parse_index_endpoint_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "index_endpoint": "abalone", + + } + path = IndexEndpointServiceClient.index_endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_index_endpoint_path(path) + assert expected == actual + +def test_index_endpoint_path(): + project = "squid" + location = "clam" + index_endpoint = "whelk" + + expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + actual = IndexEndpointServiceClient.index_endpoint_path(project, location, index_endpoint) + assert expected == actual + + +def test_parse_index_endpoint_path(): + expected = { + "project": "octopus", + "location": "oyster", + "index_endpoint": "nudibranch", + + } + path = IndexEndpointServiceClient.index_endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_index_endpoint_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = IndexEndpointServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + + } + path = IndexEndpointServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder, ) + actual = IndexEndpointServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + + } + path = IndexEndpointServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = IndexEndpointServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + + } + path = IndexEndpointServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project, ) + actual = IndexEndpointServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + + } + path = IndexEndpointServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = IndexEndpointServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + + } + path = IndexEndpointServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.IndexEndpointServiceTransport, '_prep_wrapped_messages') as prep: + client = IndexEndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.IndexEndpointServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = IndexEndpointServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py new file mode 100644 index 0000000000..416b2087cc --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py @@ -0,0 +1,2317 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.index_service import IndexServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.index_service import IndexServiceClient +from google.cloud.aiplatform_v1beta1.services.index_service import pagers +from google.cloud.aiplatform_v1beta1.services.index_service import transports +from google.cloud.aiplatform_v1beta1.types import deployed_index_ref +from google.cloud.aiplatform_v1beta1.types import index +from google.cloud.aiplatform_v1beta1.types import index as gca_index +from google.cloud.aiplatform_v1beta1.types import index_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert IndexServiceClient._get_default_mtls_endpoint(None) is None + assert IndexServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert IndexServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert IndexServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert IndexServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert IndexServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ + IndexServiceClient, + IndexServiceAsyncClient, +]) +def test_index_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +@pytest.mark.parametrize("client_class", [ + IndexServiceClient, + IndexServiceAsyncClient, +]) +def test_index_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_index_service_client_get_transport_class(): + transport = IndexServiceClient.get_transport_class() + available_transports = [ + transports.IndexServiceGrpcTransport, + ] + assert transport in available_transports + + transport = IndexServiceClient.get_transport_class("grpc") + assert transport == transports.IndexServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +@mock.patch.object(IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient)) +@mock.patch.object(IndexServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceAsyncClient)) +def test_index_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(IndexServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(IndexServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc", "true"), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc", "false"), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + +]) +@mock.patch.object(IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient)) +@mock.patch.object(IndexServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_index_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_index_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), +]) +def test_index_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_index_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = IndexServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_index(transport: str = 'grpc', request_type=index_service.CreateIndexRequest): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.CreateIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_index_from_dict(): + test_create_index(request_type=dict) + + +def test_create_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + client.create_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.CreateIndexRequest() + +@pytest.mark.asyncio +async def test_create_index_async(transport: str = 'grpc_asyncio', request_type=index_service.CreateIndexRequest): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.CreateIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_index_async_from_dict(): + await test_create_index_async(request_type=dict) + + +def test_create_index_field_headers(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.CreateIndexRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_index_field_headers_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.CreateIndexRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_index_flattened(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_index( + parent='parent_value', + index=gca_index.Index(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].index == gca_index.Index(name='name_value') + + +def test_create_index_flattened_error(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_index( + index_service.CreateIndexRequest(), + parent='parent_value', + index=gca_index.Index(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_index_flattened_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_index( + parent='parent_value', + index=gca_index.Index(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].index == gca_index.Index(name='name_value') + + +@pytest.mark.asyncio +async def test_create_index_flattened_error_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_index( + index_service.CreateIndexRequest(), + parent='parent_value', + index=gca_index.Index(name='name_value'), + ) + + +def test_get_index(transport: str = 'grpc', request_type=index_service.GetIndexRequest): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index.Index( + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + + ) + + response = client.get_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.GetIndexRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, index.Index) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.etag == 'etag_value' + + +def test_get_index_from_dict(): + test_get_index(request_type=dict) + + +def test_get_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + client.get_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.GetIndexRequest() + +@pytest.mark.asyncio +async def test_get_index_async(transport: str = 'grpc_asyncio', request_type=index_service.GetIndexRequest): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index.Index( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) + + response = await client.get_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.GetIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, index.Index) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_get_index_async_from_dict(): + await test_get_index_async(request_type=dict) + + +def test_get_index_field_headers(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.GetIndexRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + call.return_value = index.Index() + + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_index_field_headers_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.GetIndexRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index.Index()) + + await client.get_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_index_flattened(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index.Index() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_index( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_index_flattened_error(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_index( + index_service.GetIndexRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_index_flattened_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index.Index() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index.Index()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_index( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_index_flattened_error_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_index( + index_service.GetIndexRequest(), + name='name_value', + ) + + +def test_list_indexes(transport: str = 'grpc', request_type=index_service.ListIndexesRequest): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_service.ListIndexesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.ListIndexesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListIndexesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_indexes_from_dict(): + test_list_indexes(request_type=dict) + + +def test_list_indexes_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + client.list_indexes() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.ListIndexesRequest() + +@pytest.mark.asyncio +async def test_list_indexes_async(transport: str = 'grpc_asyncio', request_type=index_service.ListIndexesRequest): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_service.ListIndexesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.ListIndexesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListIndexesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_indexes_async_from_dict(): + await test_list_indexes_async(request_type=dict) + + +def test_list_indexes_field_headers(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.ListIndexesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + call.return_value = index_service.ListIndexesResponse() + + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_indexes_field_headers_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.ListIndexesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_service.ListIndexesResponse()) + + await client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_indexes_flattened(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_service.ListIndexesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_indexes( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_indexes_flattened_error(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_indexes( + index_service.ListIndexesRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_indexes_flattened_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = index_service.ListIndexesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_service.ListIndexesResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_indexes( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_indexes_flattened_error_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_indexes( + index_service.ListIndexesRequest(), + parent='parent_value', + ) + + +def test_list_indexes_pager(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + index.Index(), + ], + next_page_token='abc', + ), + index_service.ListIndexesResponse( + indexes=[], + next_page_token='def', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + ], + next_page_token='ghi', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_indexes(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, index.Index) + for i in results) + +def test_list_indexes_pages(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + index.Index(), + ], + next_page_token='abc', + ), + index_service.ListIndexesResponse( + indexes=[], + next_page_token='def', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + ], + next_page_token='ghi', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + ], + ), + RuntimeError, + ) + pages = list(client.list_indexes(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_indexes_async_pager(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + index.Index(), + ], + next_page_token='abc', + ), + index_service.ListIndexesResponse( + indexes=[], + next_page_token='def', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + ], + next_page_token='ghi', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_indexes(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, index.Index) + for i in responses) + +@pytest.mark.asyncio +async def test_list_indexes_async_pages(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_indexes), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + index.Index(), + ], + next_page_token='abc', + ), + index_service.ListIndexesResponse( + indexes=[], + next_page_token='def', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + ], + next_page_token='ghi', + ), + index_service.ListIndexesResponse( + indexes=[ + index.Index(), + index.Index(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_indexes(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_index(transport: str = 'grpc', request_type=index_service.UpdateIndexRequest): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.update_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.UpdateIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_index_from_dict(): + test_update_index(request_type=dict) + + +def test_update_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + client.update_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.UpdateIndexRequest() + +@pytest.mark.asyncio +async def test_update_index_async(transport: str = 'grpc_asyncio', request_type=index_service.UpdateIndexRequest): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.update_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.UpdateIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_index_async_from_dict(): + await test_update_index_async(request_type=dict) + + +def test_update_index_field_headers(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.UpdateIndexRequest() + request.index.name = 'index.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.update_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index.name=index.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_index_field_headers_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.UpdateIndexRequest() + request.index.name = 'index.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.update_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'index.name=index.name/value', + ) in kw['metadata'] + + +def test_update_index_flattened(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_index( + index=gca_index.Index(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].index == gca_index.Index(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_index_flattened_error(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_index( + index_service.UpdateIndexRequest(), + index=gca_index.Index(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_index_flattened_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_index( + index=gca_index.Index(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].index == gca_index.Index(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_index_flattened_error_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_index( + index_service.UpdateIndexRequest(), + index=gca_index.Index(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_index(transport: str = 'grpc', request_type=index_service.DeleteIndexRequest): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.DeleteIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_index_from_dict(): + test_delete_index(request_type=dict) + + +def test_delete_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + client.delete_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.DeleteIndexRequest() + +@pytest.mark.asyncio +async def test_delete_index_async(transport: str = 'grpc_asyncio', request_type=index_service.DeleteIndexRequest): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == index_service.DeleteIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_index_async_from_dict(): + await test_delete_index_async(request_type=dict) + + +def test_delete_index_field_headers(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.DeleteIndexRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_index_field_headers_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = index_service.DeleteIndexRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_index_flattened(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_index( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_index_flattened_error(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_index( + index_service.DeleteIndexRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_index_flattened_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_index), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_index( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_index_flattened_error_async(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_index( + index_service.DeleteIndexRequest(), + name='name_value', + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.IndexServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.IndexServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = IndexServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.IndexServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = IndexServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.IndexServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = IndexServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.IndexServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.IndexServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.IndexServiceGrpcTransport, + transports.IndexServiceGrpcAsyncIOTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.IndexServiceGrpcTransport, + ) + + +def test_index_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.IndexServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_index_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.IndexServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_index', + 'get_index', + 'list_indexes', + 'update_index', + 'delete_index', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_index_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.IndexServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_index_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.IndexServiceTransport() + adc.assert_called_once() + + +def test_index_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + IndexServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_index_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.IndexServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize("transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport]) +def test_index_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + + +def test_index_service_host_no_port(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_index_service_host_with_port(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_index_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.IndexServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_index_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.IndexServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport]) +def test_index_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport]) +def test_index_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_index_service_grpc_lro_client(): + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_index_service_grpc_lro_async_client(): + client = IndexServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_index_path(): + project = "squid" + location = "clam" + index = "whelk" + + expected = "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + actual = IndexServiceClient.index_path(project, location, index) + assert expected == actual + + +def test_parse_index_path(): + expected = { + "project": "octopus", + "location": "oyster", + "index": "nudibranch", + + } + path = IndexServiceClient.index_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_index_path(path) + assert expected == actual + +def test_index_endpoint_path(): + project = "cuttlefish" + location = "mussel" + index_endpoint = "winkle" + + expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + actual = IndexServiceClient.index_endpoint_path(project, location, index_endpoint) + assert expected == actual + + +def test_parse_index_endpoint_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "index_endpoint": "abalone", + + } + path = IndexServiceClient.index_endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_index_endpoint_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = IndexServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + + } + path = IndexServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder, ) + actual = IndexServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + + } + path = IndexServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = IndexServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + + } + path = IndexServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project, ) + actual = IndexServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + + } + path = IndexServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = IndexServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + + } + path = IndexServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = IndexServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.IndexServiceTransport, '_prep_wrapped_messages') as prep: + client = IndexServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.IndexServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = IndexServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) From 3082f97837a4f2c033c9db4ad752f51419ef17cf Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Tue, 13 Apr 2021 14:14:58 -0700 Subject: [PATCH 09/36] chore: Remove stuff (#313) --- google/cloud/aiplatform/__init__.py | 8 ++------ google/cloud/aiplatform/initializer.py | 15 +++------------ google/cloud/aiplatform/metadata/metadata.py | 18 +++++++----------- tests/unit/aiplatform/test_initializer.py | 14 -------------- 4 files changed, 12 insertions(+), 43 deletions(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 73a6342b99..eb51f5b84b 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -48,12 +48,10 @@ """ init = initializer.global_config.init -log_param = metadata.metadata_service.log_param log_params = metadata.metadata_service.log_params -log_metric = metadata.metadata_service.log_metric log_metrics = metadata.metadata_service.log_metrics -set_experiment = metadata.metadata_service.set_experiment get_experiment = metadata.metadata_service.get_experiment +get_pipeline = metadata.metadata_service.get_pipeline set_run = metadata.metadata_service.set_run @@ -61,12 +59,10 @@ "explain", "gapic", "init", - "log_param", "log_params", - "log_metric", "log_metrics", "get_experiment", - "set_experiment", + "get_pipeline", "set_run", "AutoMLImageTrainingJob", "AutoMLTabularTrainingJob", diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index c6dc61dd2e..a07bf1e779 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -56,23 +56,21 @@ def init( project: Optional[str] = None, location: Optional[str] = None, experiment: Optional[str] = None, - run: Optional[str] = None, staging_bucket: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, encryption_spec_key_name: Optional[str] = None, ): - """Updates common initalization parameters with provided options. + """Updates common initialization parameters with provided options. Args: project (str): The default project to use when making API calls. location (str): The default location to use when making API calls. If not set defaults to us-central-1 experiment (str): The experiment name - run (str): The run name staging_bucket (str): The default staging bucket to use to stage artifacts when making API calls. In the form gs://... - credentials (google.auth.crendentials.Credentials): The default custom - credentials to use when making API calls. If not provided crendentials + credentials (google.auth.credentials.Credentials): The default custom + credentials to use when making API calls. If not provided credentials will be ascertained from the environment. encryption_spec_key_name (Optional[str]): Optional. The Cloud KMS resource identifier of the customer @@ -91,13 +89,6 @@ def init( self._location = location if experiment: metadata.metadata_service.set_experiment(experiment) - if run: - if not experiment: - raise ValueError( - "No experiment set. Provide an experiment for this run, e.g., aiplatform.init(" - "experiment='my-experiment')." - ) - metadata.metadata_service.set_run(run) if staging_bucket: self._staging_bucket = staging_bucket if credentials: diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index 069b7009a7..79640bf7e1 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -43,8 +43,8 @@ def set_experiment(self, experiment: str): def set_run(self, run: str): if not self._experiment: raise ValueError( - "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') or " - "aiplatform.set_experiment(experiment='my-experiment') before trying to set_run. " + "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') " + "before trying to set_run. " ) execution = _Execution.get_or_create( resource_id=run, @@ -53,9 +53,6 @@ def set_run(self, run: str): ) self._run = execution.name - def log_param(self, name: str, value: Union[float, int, str]): - return self.log_params({name: value}) - def log_params(self, params: Dict[str, Union[float, int, str]]): self._validate_experiment_and_run(method_name="log_params") execution = _Execution.get_or_create( @@ -66,9 +63,6 @@ def log_params(self, params: Dict[str, Union[float, int, str]]): execution.update(metadata=params) self._run = execution.name - def log_metric(self, name: str, value: Union[str, float, int]): - return self.log_metrics({name: value}) - def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): self._validate_experiment_and_run(method_name="log_metrics") # Only one metrics artifact for the (experiment, run) tuple. @@ -83,16 +77,18 @@ def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): def get_experiment(self, experiment: str): raise NotImplementedError("get_experiment not implemented") + def get_pipeline(self, pipeline: str): + raise NotImplementedError("get_pipeline not implemented") + def _validate_experiment_and_run(self, method_name: str): if not self._experiment: raise ValueError( f"No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') " - f"or aiplatform.set_experiment(experiment='my-experiment') before trying to {method_name}. " + f"before trying to {method_name}. " ) if not self._run: raise ValueError( - f"No run set. Make sure to call aiplatform.init(experiment='my-experiment', " - f"run='my-run') or aiplatform.set_run('my-run') before trying to {method_name}. " + f"No run set. Make sure to call aiplatform.set_run('my-run') before trying to {method_name}. " ) diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index dcb33f5b42..088eb118b4 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -39,7 +39,6 @@ _TEST_LOCATION_2 = "europe-west4" _TEST_INVALID_LOCATION = "test-invalid-location" _TEST_EXPERIMENT = "test-experiment" -_TEST_RUN = "test-run" _TEST_STAGING_BUCKET = "test-bucket" @@ -77,19 +76,6 @@ def test_init_experiment_calls_metadata_service(self, set_experiment_mock): initializer.global_config.init(experiment=_TEST_EXPERIMENT) set_experiment_mock.assert_called_once_with(_TEST_EXPERIMENT) - def test_init_run_alone_without_experiment_raises(self): - with pytest.raises(ValueError): - initializer.global_config.init(run=_TEST_RUN) - - @patch.object(metadata_service, "set_run") - @patch.object(metadata_service, "set_experiment") - def test_init_experiment_and_run_calls_metadata_service( - self, set_experiment_mock, set_run_mock - ): - initializer.global_config.init(experiment=_TEST_EXPERIMENT, run=_TEST_RUN) - set_experiment_mock.assert_called_once_with(_TEST_EXPERIMENT) - set_run_mock.assert_called_once_with(_TEST_RUN) - def test_init_staging_bucket_sets_staging_bucket(self): initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) assert initializer.global_config.staging_bucket == _TEST_STAGING_BUCKET From 0e1b8c12ccc3473c96c8a2a805eea951ffcdb24d Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Thu, 15 Apr 2021 11:39:53 -0700 Subject: [PATCH 10/36] feat: connect artifact, executions, and contexts (#314) --- google/cloud/aiplatform/metadata/artifact.py | 5 +- google/cloud/aiplatform/metadata/constants.py | 4 + google/cloud/aiplatform/metadata/context.py | 24 +- google/cloud/aiplatform/metadata/execution.py | 27 +- google/cloud/aiplatform/metadata/metadata.py | 42 ++- .../aiplatform/metadata/metadata_store.py | 5 +- google/cloud/aiplatform/metadata/resource.py | 14 +- tests/unit/aiplatform/test_metadata.py | 265 +++++++++++++ .../aiplatform/test_metadata_resources.py | 353 +++++++++++++++++- tests/unit/aiplatform/test_metadata_store.py | 13 +- 10 files changed, 701 insertions(+), 51 deletions(-) create mode 100644 tests/unit/aiplatform/test_metadata.py diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py index 9f8505c71b..eb835cafd2 100644 --- a/google/cloud/aiplatform/metadata/artifact.py +++ b/google/cloud/aiplatform/metadata/artifact.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import proto + from typing import Optional, Dict +import proto + from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata.resource import _Resource - from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact diff --git a/google/cloud/aiplatform/metadata/constants.py b/google/cloud/aiplatform/metadata/constants.py index cc905e8b9e..bee9677cd5 100644 --- a/google/cloud/aiplatform/metadata/constants.py +++ b/google/cloud/aiplatform/metadata/constants.py @@ -26,3 +26,7 @@ SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION, SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION, } + +# The EXPERIMENT_METADATA is needed until we support context deletion in backend service. +# TODO: delete EXPERIMENT_METADATA once backend supports context deletion. +EXPERIMENT_METADATA = {"experiment_deleted": False, "experiment_type": "MB"} diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py index 11cb297365..76e6283d51 100644 --- a/google/cloud/aiplatform/metadata/context.py +++ b/google/cloud/aiplatform/metadata/context.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from typing import Optional, Dict, Sequence + import proto -from typing import Optional, Dict from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata.resource import _Resource - from google.cloud.aiplatform_v1beta1.types import context as gca_context @@ -57,3 +58,22 @@ def _update_resource( cls, client: utils.MetadataClientWithOverride, resource: proto.Message, ) -> proto.Message: return client.update_context(context=resource) + + def add_artifacts_and_executions( + self, + artifact_resource_names: Optional[Sequence[str]] = None, + execution_resource_names: Optional[Sequence[str]] = None, + ): + """Creates a new Metadata resource. + + Args: + artifact_resource_names (Sequence[str]): + Optional. The full resource name of Artifacts to attribute to the Context. + execution_resource_names (Sequence[str]): + Optional. The full resource name of Executions to associate with the Context. + """ + self.api_client.add_context_artifacts_and_executions( + context=self.resource_name, + artifacts=artifact_resource_names, + executions=execution_resource_names, + ) diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py index 060d562660..6a173e07dd 100644 --- a/google/cloud/aiplatform/metadata/execution.py +++ b/google/cloud/aiplatform/metadata/execution.py @@ -14,12 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import proto + from typing import Optional, Dict +import proto + from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata.resource import _Resource - +from google.cloud.aiplatform_v1beta1 import Event from google.cloud.aiplatform_v1beta1.types import execution as gca_execution @@ -57,3 +59,24 @@ def _update_resource( cls, client: utils.MetadataClientWithOverride, resource: proto.Message, ) -> proto.Message: return client.update_execution(execution=resource) + + def add_artifact( + self, artifact_resource_name: str, input: bool, + ): + """Creates a new Metadata resource. + + Args: + artifact_resource_name (str): + Required. The full resource name of the Artifact to connect to the Execution through an Event. + input (bool) + Required. Whether Artifact is an input event to the Execution or not. + """ + + event = Event( + artifact=artifact_resource_name, + type_=Event.Type.INPUT if input else Event.Type.OUTPUT, + ) + + self.api_client.add_execution_events( + execution=self.resource_name, events=[event], + ) diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index 79640bf7e1..c1b76d04d0 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -17,11 +17,11 @@ from typing import Dict, Union -from google.cloud.aiplatform.metadata.metadata_store import _MetadataStore from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata.artifact import _Artifact from google.cloud.aiplatform.metadata.context import _Context from google.cloud.aiplatform.metadata.execution import _Execution -from google.cloud.aiplatform.metadata.artifact import _Artifact +from google.cloud.aiplatform.metadata.metadata_store import _MetadataStore class _MetadataService: @@ -30,15 +30,18 @@ class _MetadataService: def __init__(self): self._experiment = None self._run = None + self._metrics = None def set_experiment(self, experiment: str): _MetadataStore.get_or_create() context = _Context.get_or_create( resource_id=experiment, + display_name=experiment, schema_title=constants.SYSTEM_EXPERIMENT, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, ) - self._experiment = context.name + self._experiment = context def set_run(self, run: str): if not self._experiment: @@ -46,29 +49,46 @@ def set_run(self, run: str): "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') " "before trying to set_run. " ) - execution = _Execution.get_or_create( - resource_id=run, + run_execution_id = f"{self._experiment.name}-{run}" + run_execution = _Execution.get_or_create( + resource_id=run_execution_id, + display_name=run, schema_title=constants.SYSTEM_RUN, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], ) - self._run = execution.name + self._experiment.add_artifacts_and_executions( + execution_resource_names=[run_execution.resource_name] + ) + + metrics_artifact_id = f"{self._experiment.name}-{run}-metrics" + metrics_artifact = _Artifact.get_or_create( + resource_id=metrics_artifact_id, + display_name=metrics_artifact_id, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + run_execution.add_artifact( + artifact_resource_name=metrics_artifact.resource_name, input=False + ) + + self._run = run_execution + self._metrics = metrics_artifact def log_params(self, params: Dict[str, Union[float, int, str]]): self._validate_experiment_and_run(method_name="log_params") + # query the latest run execution resource before logging. execution = _Execution.get_or_create( - resource_id=self._run, + resource_id=self._run.name, schema_title=constants.SYSTEM_RUN, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], ) execution.update(metadata=params) - self._run = execution.name def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): self._validate_experiment_and_run(method_name="log_metrics") - # Only one metrics artifact for the (experiment, run) tuple. - artifact_id = f"{self._experiment}-{self._run}" + # query the latest metrics artifact resource before logging. artifact = _Artifact.get_or_create( - resource_id=artifact_id, + resource_id=self._metrics.name, schema_title=constants.SYSTEM_METRICS, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], ) diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py index 3c3664b1d5..2a55f066a8 100644 --- a/google/cloud/aiplatform/metadata/metadata_store.py +++ b/google/cloud/aiplatform/metadata/metadata_store.py @@ -14,11 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Optional import logging -from google.auth import credentials as auth_credentials +from typing import Optional + from google.api_core import exceptions +from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import compat diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py index ab0a3a9aa4..022e51e7e8 100644 --- a/google/cloud/aiplatform/metadata/resource.py +++ b/google/cloud/aiplatform/metadata/resource.py @@ -14,18 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import abc -import re -import proto import logging -from typing import Optional, Dict +import re from copy import deepcopy +from typing import Optional, Dict +import proto from google.api_core import exceptions -from google.cloud.aiplatform import utils from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base, initializer +from google.cloud.aiplatform import utils class _Resource(base.AiPlatformResourceNounWithFutureManager, abc.ABC): @@ -177,7 +178,10 @@ def update( """ gca_resource = deepcopy(self._gca_resource) - gca_resource.metadata.update(metadata) + if gca_resource.metadata: + gca_resource.metadata.update(metadata) + else: + gca_resource.metadata = metadata api_client = self._instantiate_client(credentials=credentials) update_gca_resource = self._update_resource( diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py new file mode 100644 index 0000000000..cae78d85ee --- /dev/null +++ b/tests/unit/aiplatform/test_metadata.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +from unittest.mock import patch + +import pytest + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata import metadata +from google.cloud.aiplatform_v1beta1 import ( + AddContextArtifactsAndExecutionsResponse, + Event, +) +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact +from google.cloud.aiplatform_v1beta1 import Context as GapicContext +from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution +from google.cloud.aiplatform_v1beta1 import ( + MetadataServiceClient, + AddExecutionEventsResponse, +) +from google.cloud.aiplatform_v1beta1 import MetadataStore as GapicMetadataStore + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) +_TEST_EXPERIMENT = "test-experiment" +_TEST_RUN = "run" + +# resource attributes +_TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True} + +# metadataStore +_TEST_METADATASTORE = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) + +# context +_TEST_CONTEXT_ID = _TEST_EXPERIMENT +_TEST_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_CONTEXT_ID}" + +# execution +_TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}" +_TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}" + +# artifact +_TEST_ARTIFACT_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}-metrics" +_TEST_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_ARTIFACT_ID}" + +# parameters +_TEST_PARAMS = {"learning_rate": 0.01, "dropout": 0.2} + +# metrics +_TEST_METRICS = {"rmse": 222, "accuracy": 1} + + +@pytest.fixture +def get_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + name=_TEST_METADATASTORE, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def get_context_mock(): + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: + get_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + yield get_context_mock + + +@pytest.fixture +def add_context_artifacts_and_executions_mock(): + with patch.object( + MetadataServiceClient, "add_context_artifacts_and_executions" + ) as add_context_artifacts_and_executions_mock: + add_context_artifacts_and_executions_mock.return_value = ( + AddContextArtifactsAndExecutionsResponse() + ) + yield add_context_artifacts_and_executions_mock + + +@pytest.fixture +def get_execution_mock(): + with patch.object(MetadataServiceClient, "get_execution") as get_execution_mock: + get_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + yield get_execution_mock + + +@pytest.fixture +def update_execution_mock(): + with patch.object( + MetadataServiceClient, "update_execution" + ) as update_execution_mock: + update_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ) + yield update_execution_mock + + +@pytest.fixture +def add_execution_events_mock(): + with patch.object( + MetadataServiceClient, "add_execution_events" + ) as add_execution_events_mock: + add_execution_events_mock.return_value = AddExecutionEventsResponse() + yield add_execution_events_mock + + +@pytest.fixture +def get_artifact_mock(): + with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: + get_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + yield get_artifact_mock + + +@pytest.fixture +def update_artifact_mock(): + with patch.object(MetadataServiceClient, "update_artifact") as update_artifact_mock: + update_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + metadata=_TEST_METRICS, + ) + yield update_artifact_mock + + +class TestMetadata: + def setup_method(self): + reload(initializer) + reload(aiplatform) + reload(metadata) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_experiment_with_existing_metadataStore_and_context( + self, get_metadata_store_mock, get_context_mock + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + + get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + def test_set_run_with_existing_execution_and_artifact( + self, + get_execution_mock, + add_context_artifacts_and_executions_mock, + get_artifact_mock, + add_execution_events_mock, + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.set_run(_TEST_RUN) + + get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=None, + executions=[_TEST_EXECUTION_NAME], + ) + get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) + add_execution_events_mock.assert_called_once_with( + execution=_TEST_EXECUTION_NAME, + events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], + ) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_log_params( + self, update_execution_mock, + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.set_run(_TEST_RUN) + aiplatform.log_params(_TEST_PARAMS) + + updated_execution = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ) + + update_execution_mock.assert_called_once_with(execution=updated_execution) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_log_metrics( + self, update_artifact_mock, + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.set_run(_TEST_RUN) + aiplatform.log_metrics(_TEST_METRICS) + + updated_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + metadata=_TEST_METRICS, + ) + + update_artifact_mock.assert_called_once_with(artifact=updated_artifact) diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py index dca0c569fa..a46860f37c 100644 --- a/tests/unit/aiplatform/test_metadata_resources.py +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -15,21 +15,26 @@ # limitations under the License. # -import pytest - from importlib import reload -from unittest.mock import patch +from unittest.mock import patch, call + +import pytest +from google.api_core import exceptions from google.cloud import aiplatform from google.cloud.aiplatform import initializer -from google.cloud.aiplatform.metadata import context from google.cloud.aiplatform.metadata import artifact +from google.cloud.aiplatform.metadata import context from google.cloud.aiplatform.metadata import execution - -from google.cloud.aiplatform_v1beta1 import MetadataServiceClient +from google.cloud.aiplatform_v1beta1 import AddContextArtifactsAndExecutionsResponse +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact from google.cloud.aiplatform_v1beta1 import Context as GapicContext from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution -from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact +from google.cloud.aiplatform_v1beta1 import ( + MetadataServiceClient, + AddExecutionEventsResponse, + Event, +) # project _TEST_PROJECT = "test-project" @@ -44,6 +49,11 @@ _TEST_SCHEMA_VERSION = "0.0.1" _TEST_DESCRIPTION = "test description" _TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True} +_TEST_UPDATED_METADATA = { + "test-param1": 2, + "test-param2": "test-value-1", + "test-param3": False, +} # context _TEST_CONTEXT_ID = "test-context-id" @@ -72,6 +82,25 @@ def get_context_mock(): yield get_context_mock +@pytest.fixture +def get_context_for_get_or_create_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_context_for_get_or_create_mock: + get_context_for_get_or_create_mock.side_effect = [ + exceptions.NotFound("test: Context Not Found"), + GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield get_context_for_get_or_create_mock + + @pytest.fixture def create_context_mock(): with patch.object(MetadataServiceClient, "create_context") as create_context_mock: @@ -86,6 +115,31 @@ def create_context_mock(): yield create_context_mock +@pytest.fixture +def update_context_mock(): + with patch.object(MetadataServiceClient, "update_context") as update_context_mock: + update_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + yield update_context_mock + + +@pytest.fixture +def add_context_artifacts_and_executions_mock(): + with patch.object( + MetadataServiceClient, "add_context_artifacts_and_executions" + ) as add_context_artifacts_and_executions_mock: + add_context_artifacts_and_executions_mock.return_value = ( + AddContextArtifactsAndExecutionsResponse() + ) + yield add_context_artifacts_and_executions_mock + + @pytest.fixture def get_execution_mock(): with patch.object(MetadataServiceClient, "get_execution") as get_execution_mock: @@ -100,6 +154,25 @@ def get_execution_mock(): yield get_execution_mock +@pytest.fixture +def get_execution_for_get_or_create_mock(): + with patch.object( + MetadataServiceClient, "get_execution" + ) as get_execution_for_get_or_create_mock: + get_execution_for_get_or_create_mock.side_effect = [ + exceptions.NotFound("test: Execution Not Found"), + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield get_execution_for_get_or_create_mock + + @pytest.fixture def create_execution_mock(): with patch.object( @@ -116,6 +189,31 @@ def create_execution_mock(): yield create_execution_mock +@pytest.fixture +def update_execution_mock(): + with patch.object( + MetadataServiceClient, "update_execution" + ) as update_execution_mock: + update_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + yield update_execution_mock + + +@pytest.fixture +def add_execution_events_mock(): + with patch.object( + MetadataServiceClient, "add_execution_events" + ) as add_execution_events_mock: + add_execution_events_mock.return_value = AddExecutionEventsResponse() + yield add_execution_events_mock + + @pytest.fixture def get_artifact_mock(): with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: @@ -130,6 +228,25 @@ def get_artifact_mock(): yield get_artifact_mock +@pytest.fixture +def get_artifact_for_get_or_create_mock(): + with patch.object( + MetadataServiceClient, "get_artifact" + ) as get_artifact_for_get_or_create_mock: + get_artifact_for_get_or_create_mock.side_effect = [ + exceptions.NotFound("test: Artifact Not Found"), + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield get_artifact_for_get_or_create_mock + + @pytest.fixture def create_artifact_mock(): with patch.object(MetadataServiceClient, "create_artifact") as create_artifact_mock: @@ -144,6 +261,20 @@ def create_artifact_mock(): yield create_artifact_mock +@pytest.fixture +def update_artifact_mock(): + with patch.object(MetadataServiceClient, "update_artifact") as update_artifact_mock: + update_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + yield update_artifact_mock + + class TestContext: def setup_method(self): reload(initializer) @@ -164,11 +295,12 @@ def test_init_context_with_id(self, get_context_mock): ) get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) - @pytest.mark.usefixtures("get_context_mock") - def test_create_context(self, create_context_mock): + def test_get_or_create_context( + self, get_context_for_get_or_create_mock, create_context_mock + ): aiplatform.init(project=_TEST_PROJECT) - my_context = context._Context._create( + my_context = context._Context.get_or_create( resource_id=_TEST_CONTEXT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -185,7 +317,9 @@ def test_create_context(self, create_context_mock): description=_TEST_DESCRIPTION, metadata=_TEST_METADATA, ) - + get_context_for_get_or_create_mock.assert_has_calls( + calls=[call(name=_TEST_CONTEXT_NAME), call(name=_TEST_CONTEXT_NAME)] + ) create_context_mock.assert_called_once_with( parent=_TEST_PARENT, context_id=_TEST_CONTEXT_ID, context=expected_context, ) @@ -193,6 +327,104 @@ def test_create_context(self, create_context_mock): expected_context.name = _TEST_CONTEXT_NAME assert my_context._gca_resource == expected_context + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("create_context_mock") + def test_update_context(self, update_context_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context._Context._create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_context.update(_TEST_UPDATED_METADATA) + + updated_context = GapicContext( + name=_TEST_CONTEXT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_context_mock.assert_called_once_with(context=updated_context,) + assert my_context._gca_resource == updated_context + + @pytest.mark.usefixtures("get_context_mock") + def test_add_artifacts_and_executions( + self, add_context_artifacts_and_executions_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + my_context.add_artifacts_and_executions( + artifact_resource_names=[_TEST_ARTIFACT_NAME], + execution_resource_names=[_TEST_EXECUTION_NAME], + ) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=[_TEST_ARTIFACT_NAME], + executions=[_TEST_EXECUTION_NAME], + ) + + @pytest.mark.usefixtures("get_context_mock") + def test_add_artifacts_only(self, add_context_artifacts_and_executions_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + my_context.add_artifacts_and_executions( + artifact_resource_names=[_TEST_ARTIFACT_NAME] + ) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=[_TEST_ARTIFACT_NAME], + executions=None, + ) + + @pytest.mark.usefixtures("get_context_mock") + def test_add_executions_only(self, add_context_artifacts_and_executions_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + my_context.add_artifacts_and_executions( + execution_resource_names=[_TEST_EXECUTION_NAME] + ) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=None, + executions=[_TEST_EXECUTION_NAME], + ) + class TestExecution: def setup_method(self): @@ -214,11 +446,12 @@ def test_init_execution_with_id(self, get_execution_mock): ) get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) - @pytest.mark.usefixtures("get_execution_mock") - def test_create_execution(self, create_execution_mock): + def test_get_or_create_execution( + self, get_execution_for_get_or_create_mock, create_execution_mock + ): aiplatform.init(project=_TEST_PROJECT) - my_execution = execution._Execution._create( + my_execution = execution._Execution.get_or_create( resource_id=_TEST_EXECUTION_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -235,7 +468,9 @@ def test_create_execution(self, create_execution_mock): description=_TEST_DESCRIPTION, metadata=_TEST_METADATA, ) - + get_execution_for_get_or_create_mock.assert_has_calls( + calls=[call(name=_TEST_EXECUTION_NAME), call(name=_TEST_EXECUTION_NAME)] + ) create_execution_mock.assert_called_once_with( parent=_TEST_PARENT, execution_id=_TEST_EXECUTION_ID, @@ -245,6 +480,55 @@ def test_create_execution(self, create_execution_mock): expected_execution.name = _TEST_EXECUTION_NAME assert my_execution._gca_resource == expected_execution + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("create_execution_mock") + def test_update_execution(self, update_execution_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_execution = execution._Execution._create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_execution.update(_TEST_UPDATED_METADATA) + + updated_execution = GapicExecution( + name=_TEST_EXECUTION_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_execution_mock.assert_called_once_with(execution=updated_execution) + assert my_execution._gca_resource == updated_execution + + @pytest.mark.usefixtures("get_execution_mock") + def test_add_artifact(self, add_execution_events_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_execution = execution._Execution.get_or_create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_execution.add_artifact( + artifact_resource_name=_TEST_ARTIFACT_NAME, input=False, + ) + add_execution_events_mock.assert_called_once_with( + execution=_TEST_EXECUTION_NAME, + events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], + ) + class TestArtifact: def setup_method(self): @@ -266,11 +550,12 @@ def test_init_artifact_with_id(self, get_artifact_mock): ) get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) - @pytest.mark.usefixtures("get_artifact_mock") - def test_create_artifact(self, create_artifact_mock): + def test_get_or_create_artifact( + self, get_artifact_for_get_or_create_mock, create_artifact_mock + ): aiplatform.init(project=_TEST_PROJECT) - my_artifact = artifact._Artifact._create( + my_artifact = artifact._Artifact.get_or_create( resource_id=_TEST_ARTIFACT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -287,7 +572,9 @@ def test_create_artifact(self, create_artifact_mock): description=_TEST_DESCRIPTION, metadata=_TEST_METADATA, ) - + get_artifact_for_get_or_create_mock.assert_has_calls( + calls=[call(name=_TEST_ARTIFACT_NAME), call(name=_TEST_ARTIFACT_NAME)] + ) create_artifact_mock.assert_called_once_with( parent=_TEST_PARENT, artifact_id=_TEST_ARTIFACT_ID, @@ -296,3 +583,31 @@ def test_create_artifact(self, create_artifact_mock): expected_artifact.name = _TEST_ARTIFACT_NAME assert my_artifact._gca_resource == expected_artifact + + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("create_artifact_mock") + def test_update_artifact(self, update_artifact_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_artifact = artifact._Artifact._create( + resource_id=_TEST_ARTIFACT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_artifact.update(_TEST_UPDATED_METADATA) + + updated_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_artifact_mock.assert_called_once_with(artifact=updated_artifact) + assert my_artifact._gca_resource == updated_artifact diff --git a/tests/unit/aiplatform/test_metadata_store.py b/tests/unit/aiplatform/test_metadata_store.py index d8c38b8baf..516e61d849 100644 --- a/tests/unit/aiplatform/test_metadata_store.py +++ b/tests/unit/aiplatform/test_metadata_store.py @@ -16,25 +16,22 @@ # import os - -import pytest - -from unittest import mock from importlib import reload +from unittest import mock from unittest.mock import patch +import pytest from google.api_core import operation -from google.auth.exceptions import GoogleAuthError from google.auth import credentials as auth_credentials +from google.auth.exceptions import GoogleAuthError from google.cloud import aiplatform -from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform import initializer - +from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform_v1beta1 import MetadataServiceClient from google.cloud.aiplatform_v1beta1 import MetadataStore as GapicMetadataStore -from google.cloud.aiplatform_v1beta1.types import metadata_service from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1beta1.types import metadata_service # project _TEST_PROJECT = "test-project" From 763c1c9f0fb6a5978208dee879cda73907d5fae5 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Fri, 16 Apr 2021 16:28:09 -0400 Subject: [PATCH 11/36] test --- .kokoro/release.sh | 4 +- .kokoro/release/common.cfg | 14 +- .kokoro/samples/python3.8/common.cfg | 6 + CHANGELOG.md | 26 + docs/_static/custom.css | 13 +- docs/aiplatform.rst | 6 + docs/index.rst | 1 + google/cloud/aiplatform/__init__.py | 48 +- google/cloud/aiplatform/base.py | 1022 ++++ google/cloud/aiplatform/compat/__init__.py | 122 + .../aiplatform/compat/services/__init__.py | 79 + .../cloud/aiplatform/compat/types/__init__.py | 158 + google/cloud/aiplatform/constants.py | 62 + google/cloud/aiplatform/datasets/__init__.py | 31 + .../cloud/aiplatform/datasets/_datasources.py | 236 + google/cloud/aiplatform/datasets/dataset.py | 577 +++ .../aiplatform/datasets/image_dataset.py | 149 + .../aiplatform/datasets/tabular_dataset.py | 134 + .../cloud/aiplatform/datasets/text_dataset.py | 156 + .../aiplatform/datasets/video_dataset.py | 149 + google/cloud/aiplatform/explain/__init__.py | 59 + .../cloud/aiplatform/helpers/_decorators.py | 2 + google/cloud/aiplatform/initializer.py | 279 ++ google/cloud/aiplatform/jobs.py | 795 +++ google/cloud/aiplatform/models.py | 1997 ++++++++ google/cloud/aiplatform/schema.py | 76 + google/cloud/aiplatform/training_jobs.py | 4362 +++++++++++++++++ google/cloud/aiplatform/training_utils.py | 105 + google/cloud/aiplatform/utils.py | 469 ++ .../services/dataset_service/async_client.py | 20 +- .../dataset_service/transports/base.py | 20 +- .../services/endpoint_service/async_client.py | 14 +- .../endpoint_service/transports/base.py | 14 +- .../services/job_service/async_client.py | 40 +- .../services/job_service/transports/base.py | 40 +- .../services/migration_service/client.py | 12 +- .../services/model_service/async_client.py | 20 +- .../services/model_service/transports/base.py | 20 +- .../services/pipeline_service/async_client.py | 10 +- .../pipeline_service/transports/base.py | 10 +- .../prediction_service/async_client.py | 2 +- .../prediction_service/transports/base.py | 2 +- .../specialist_pool_service/async_client.py | 10 +- .../transports/base.py | 10 +- .../featurestore_service/async_client.py | 4 +- .../services/featurestore_service/client.py | 4 +- .../services/migration_service/client.py | 12 +- .../types/featurestore_service.py | 11 +- google/cloud/aiplatform_v1beta1/types/io.py | 4 +- samples/model-builder/conftest.py | 205 + .../create_and_import_dataset_image_sample.py | 44 + ...te_and_import_dataset_image_sample_test.py | 41 + .../create_and_import_dataset_text_sample.py | 44 + ...ate_and_import_dataset_text_sample_test.py | 39 + .../create_batch_prediction_job_sample.py | 49 + ...create_batch_prediction_job_sample_test.py | 42 + ...ng_pipeline_image_classification_sample.py | 57 + ...peline_image_classification_sample_test.py | 57 + ...text_classification_single_label_sample.py | 44 + ...classification_single_label_sample_test.py | 43 + ...port_data_text_entity_extraction_sample.py | 44 + ...data_text_entity_extraction_sample_test.py | 45 + ...ort_data_text_sentiment_analysis_sample.py | 44 + ...ata_text_sentiment_analysis_sample_test.py | 45 + samples/model-builder/init_sample.py | 40 + samples/model-builder/init_sample_test.py | 38 + samples/model-builder/noxfile.py | 221 + ...text_classification_single_label_sample.py | 33 + ...classification_single_label_sample_test.py | 37 + .../predict_text_entity_extraction_sample.py | 32 + ...dict_text_entity_extraction_sample_test.py | 35 + .../predict_text_sentiment_analysis_sample.py | 32 + ...ict_text_sentiment_analysis_sample_test.py | 35 + samples/model-builder/requirements-tests.txt | 1 + samples/model-builder/requirements.txt | 2 + samples/model-builder/test_constants.py | 53 + samples/snippets/requirements.txt | 2 +- setup.py | 3 +- tests/system/aiplatform/test_dataset.py | 287 ++ .../test_automl_image_training_jobs.py | 434 ++ .../test_automl_tabular_training_jobs.py | 441 ++ .../test_automl_text_training_jobs.py | 618 +++ .../test_automl_video_training_jobs.py | 463 ++ tests/unit/aiplatform/test_base.py | 201 + tests/unit/aiplatform/test_datasets.py | 1190 +++++ tests/unit/aiplatform/test_end_to_end.py | 462 ++ tests/unit/aiplatform/test_endpoints.py | 1079 ++++ tests/unit/aiplatform/test_initializer.py | 170 + tests/unit/aiplatform/test_jobs.py | 639 +++ tests/unit/aiplatform/test_models.py | 1130 +++++ tests/unit/aiplatform/test_training_jobs.py | 3865 +++++++++++++++ tests/unit/aiplatform/test_training_utils.py | 144 + tests/unit/aiplatform/test_utils.py | 305 ++ .../aiplatform_v1/test_migration_service.py | 36 +- .../test_migration_service.py | 24 +- 95 files changed, 24128 insertions(+), 183 deletions(-) create mode 100644 docs/aiplatform.rst create mode 100644 google/cloud/aiplatform/base.py create mode 100644 google/cloud/aiplatform/compat/__init__.py create mode 100644 google/cloud/aiplatform/compat/services/__init__.py create mode 100644 google/cloud/aiplatform/compat/types/__init__.py create mode 100644 google/cloud/aiplatform/constants.py create mode 100644 google/cloud/aiplatform/datasets/__init__.py create mode 100644 google/cloud/aiplatform/datasets/_datasources.py create mode 100644 google/cloud/aiplatform/datasets/dataset.py create mode 100644 google/cloud/aiplatform/datasets/image_dataset.py create mode 100644 google/cloud/aiplatform/datasets/tabular_dataset.py create mode 100644 google/cloud/aiplatform/datasets/text_dataset.py create mode 100644 google/cloud/aiplatform/datasets/video_dataset.py create mode 100644 google/cloud/aiplatform/explain/__init__.py create mode 100644 google/cloud/aiplatform/initializer.py create mode 100644 google/cloud/aiplatform/jobs.py create mode 100644 google/cloud/aiplatform/models.py create mode 100644 google/cloud/aiplatform/schema.py create mode 100644 google/cloud/aiplatform/training_jobs.py create mode 100644 google/cloud/aiplatform/training_utils.py create mode 100644 google/cloud/aiplatform/utils.py create mode 100644 samples/model-builder/conftest.py create mode 100644 samples/model-builder/create_and_import_dataset_image_sample.py create mode 100644 samples/model-builder/create_and_import_dataset_image_sample_test.py create mode 100644 samples/model-builder/create_and_import_dataset_text_sample.py create mode 100644 samples/model-builder/create_and_import_dataset_text_sample_test.py create mode 100644 samples/model-builder/create_batch_prediction_job_sample.py create mode 100644 samples/model-builder/create_batch_prediction_job_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_image_classification_sample.py create mode 100644 samples/model-builder/create_training_pipeline_image_classification_sample_test.py create mode 100644 samples/model-builder/import_data_text_classification_single_label_sample.py create mode 100644 samples/model-builder/import_data_text_classification_single_label_sample_test.py create mode 100644 samples/model-builder/import_data_text_entity_extraction_sample.py create mode 100644 samples/model-builder/import_data_text_entity_extraction_sample_test.py create mode 100644 samples/model-builder/import_data_text_sentiment_analysis_sample.py create mode 100644 samples/model-builder/import_data_text_sentiment_analysis_sample_test.py create mode 100644 samples/model-builder/init_sample.py create mode 100644 samples/model-builder/init_sample_test.py create mode 100644 samples/model-builder/noxfile.py create mode 100644 samples/model-builder/predict_text_classification_single_label_sample.py create mode 100644 samples/model-builder/predict_text_classification_single_label_sample_test.py create mode 100644 samples/model-builder/predict_text_entity_extraction_sample.py create mode 100644 samples/model-builder/predict_text_entity_extraction_sample_test.py create mode 100644 samples/model-builder/predict_text_sentiment_analysis_sample.py create mode 100644 samples/model-builder/predict_text_sentiment_analysis_sample_test.py create mode 100644 samples/model-builder/requirements-tests.txt create mode 100644 samples/model-builder/requirements.txt create mode 100644 samples/model-builder/test_constants.py create mode 100644 tests/system/aiplatform/test_dataset.py create mode 100644 tests/unit/aiplatform/test_automl_image_training_jobs.py create mode 100644 tests/unit/aiplatform/test_automl_tabular_training_jobs.py create mode 100644 tests/unit/aiplatform/test_automl_text_training_jobs.py create mode 100644 tests/unit/aiplatform/test_automl_video_training_jobs.py create mode 100644 tests/unit/aiplatform/test_base.py create mode 100644 tests/unit/aiplatform/test_datasets.py create mode 100644 tests/unit/aiplatform/test_end_to_end.py create mode 100644 tests/unit/aiplatform/test_endpoints.py create mode 100644 tests/unit/aiplatform/test_initializer.py create mode 100644 tests/unit/aiplatform/test_jobs.py create mode 100644 tests/unit/aiplatform/test_models.py create mode 100644 tests/unit/aiplatform/test_training_jobs.py create mode 100644 tests/unit/aiplatform/test_training_utils.py create mode 100644 tests/unit/aiplatform/test_utils.py diff --git a/.kokoro/release.sh b/.kokoro/release.sh index ab2a347901..62bdb892ff 100755 --- a/.kokoro/release.sh +++ b/.kokoro/release.sh @@ -26,7 +26,7 @@ python3 -m pip install --upgrade twine wheel setuptools export PYTHONUNBUFFERED=1 # Move into the package, build the distribution and upload. -TWINE_PASSWORD=$(cat "${KOKORO_KEYSTORE_DIR}/73713_google_cloud_pypi_password") +TWINE_PASSWORD=$(cat "${KOKORO_GFILE_DIR}/secret_manager/google-cloud-pypi-token") cd github/python-aiplatform python3 setup.py sdist bdist_wheel -twine upload --username gcloudpypi --password "${TWINE_PASSWORD}" dist/* +twine upload --username __token__ --password "${TWINE_PASSWORD}" dist/* diff --git a/.kokoro/release/common.cfg b/.kokoro/release/common.cfg index ff589f8e66..5293e75110 100644 --- a/.kokoro/release/common.cfg +++ b/.kokoro/release/common.cfg @@ -23,18 +23,8 @@ env_vars: { value: "github/python-aiplatform/.kokoro/release.sh" } -# Fetch PyPI password -before_action { - fetch_keystore { - keystore_resource { - keystore_config_id: 73713 - keyname: "google_cloud_pypi_password" - } - } -} - # Tokens needed to report release status back to GitHub env_vars: { key: "SECRET_MANAGER_KEYS" - value: "releasetool-publish-reporter-app,releasetool-publish-reporter-googleapis-installation,releasetool-publish-reporter-pem" -} \ No newline at end of file + value: "releasetool-publish-reporter-app,releasetool-publish-reporter-googleapis-installation,releasetool-publish-reporter-pem,google-cloud-pypi-token" +} diff --git a/.kokoro/samples/python3.8/common.cfg b/.kokoro/samples/python3.8/common.cfg index 512c9ee399..21b411c8e1 100644 --- a/.kokoro/samples/python3.8/common.cfg +++ b/.kokoro/samples/python3.8/common.cfg @@ -19,6 +19,12 @@ env_vars: { value: "py-3.8" } +# Run tests located under tests/system +env_vars: { + key: "RUN_SYSTEM_TESTS" + value: "true" +} + env_vars: { key: "TRAMPOLINE_BUILD_FILE" value: "github/python-aiplatform/.kokoro/test-samples.sh" diff --git a/CHANGELOG.md b/CHANGELOG.md index be2d9a602f..2c6398b03b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,31 @@ # Changelog +### [0.7.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.7.0...v0.7.1) (2021-04-14) + + +### Bug Fixes + +* fix list failing without order_by and local sorting ([#320](https://www.github.com/googleapis/python-aiplatform/issues/320)) ([06e99db](https://www.github.com/googleapis/python-aiplatform/commit/06e99db849d954344aeb8bdefde41d1884e36315)) + +## [0.7.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.6.0...v0.7.0) (2021-04-14) + + +### Features + +* Add Custom Container Prediction support, move to single API endpoint ([#277](https://www.github.com/googleapis/python-aiplatform/issues/277)) ([ca7f6d6](https://www.github.com/googleapis/python-aiplatform/commit/ca7f6d64ea75349a841b53fe6ef6547942439e35)) +* Add initial Model Builder SDK samples ([#265](https://www.github.com/googleapis/python-aiplatform/issues/265)) ([1230dc6](https://www.github.com/googleapis/python-aiplatform/commit/1230dc68a34c5b747186d31a25d1b8f40bf7a97e)) +* Add list() method to all resource nouns ([#294](https://www.github.com/googleapis/python-aiplatform/issues/294)) ([3ec9386](https://www.github.com/googleapis/python-aiplatform/commit/3ec9386f8f766662c91922af66b8098ddfa1eb8f)) +* add support for multiple client versions, change aiplatform from compat.V1BETA1 to compat.V1 ([#290](https://www.github.com/googleapis/python-aiplatform/issues/290)) ([89e3212](https://www.github.com/googleapis/python-aiplatform/commit/89e321246b6223a2355947d8dbd0161b84523478)) +* Make aiplatform.Dataset private ([#296](https://www.github.com/googleapis/python-aiplatform/issues/296)) ([1f0d5f3](https://www.github.com/googleapis/python-aiplatform/commit/1f0d5f3e3f95ee5056545e9d4742b96e9380a22e)) +* parse project location when passed full resource name to get apis ([#297](https://www.github.com/googleapis/python-aiplatform/issues/297)) ([674227d](https://www.github.com/googleapis/python-aiplatform/commit/674227d2e7ed4a4a4e180213dc1178dde7d65a3a)) + + +### Bug Fixes + +* add quotes to logged snippet ([0ecd0a8](https://www.github.com/googleapis/python-aiplatform/commit/0ecd0a8bbc5a2fc645877d0eb3b930e1b03a270a)) +* make logging more informative during training ([#310](https://www.github.com/googleapis/python-aiplatform/issues/310)) ([9a4d991](https://www.github.com/googleapis/python-aiplatform/commit/9a4d99150a035b8dde7b4f9e72f25745af17b609)) +* remove TPU from accelerator test cases ([57f4fcf](https://www.github.com/googleapis/python-aiplatform/commit/57f4fcf7637467f6176436f6d2e1f6c8be909c4a)) + ## [0.6.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.1...v0.6.0) (2021-03-22) diff --git a/docs/_static/custom.css b/docs/_static/custom.css index bcd37bbd3c..b0a295464b 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -1,9 +1,20 @@ div#python2-eol { border-color: red; border-width: medium; -} +} /* Ensure minimum width for 'Parameters' / 'Returns' column */ dl.field-list > dt { min-width: 100px } + +/* Insert space between methods for readability */ +dl.method { + padding-top: 10px; + padding-bottom: 10px +} + +/* Insert empty space between classes */ +dl.class { + padding-bottom: 50px +} diff --git a/docs/aiplatform.rst b/docs/aiplatform.rst new file mode 100644 index 0000000000..bf5cd4625b --- /dev/null +++ b/docs/aiplatform.rst @@ -0,0 +1,6 @@ +Google Cloud Aiplatform SDK +============================================= + +.. automodule:: google.cloud.aiplatform + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 765eb55989..031271a261 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,6 +7,7 @@ API Reference .. toctree:: :maxdepth: 2 + aiplatform aiplatform_v1/services aiplatform_v1/types diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index bb196e2c19..58eb824454 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -16,6 +16,52 @@ # from google.cloud.aiplatform import gapic +from google.cloud.aiplatform import explain +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.datasets import ( + ImageDataset, + TabularDataset, + TextDataset, + VideoDataset, +) +from google.cloud.aiplatform.models import Endpoint +from google.cloud.aiplatform.models import Model +from google.cloud.aiplatform.jobs import BatchPredictionJob +from google.cloud.aiplatform.training_jobs import ( + CustomTrainingJob, + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + AutoMLTabularTrainingJob, + AutoMLImageTrainingJob, + AutoMLTextTrainingJob, + AutoMLVideoTrainingJob, +) -__all__ = ("gapic",) +""" +Usage: +from google.cloud import aiplatform + +aiplatform.init(project='my_project') +""" +init = initializer.global_config.init + +__all__ = ( + "explain", + "gapic", + "init", + "AutoMLImageTrainingJob", + "AutoMLTabularTrainingJob", + "AutoMLTextTrainingJob", + "AutoMLVideoTrainingJob", + "BatchPredictionJob", + "CustomTrainingJob", + "CustomContainerTrainingJob", + "CustomPythonPackageTrainingJob", + "Endpoint", + "ImageDataset", + "Model", + "TabularDataset", + "TextDataset", + "VideoDataset", +) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py new file mode 100644 index 0000000000..907397b7e8 --- /dev/null +++ b/google/cloud/aiplatform/base.py @@ -0,0 +1,1022 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +from concurrent import futures +import datetime +import functools +import inspect +import logging +import sys +import threading +from typing import ( + Any, + Callable, + Dict, + List, + Iterable, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import proto + +from google.api_core import operation +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils + + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) + + +class Logger: + """Logging wrapper class with high level helper methods.""" + + def __init__(self, name: str = ""): + """Initializes logger with name. + + Args: + name (str): Name to associate with logger. + """ + self._logger = logging.getLogger(name) + + def log_create_with_lro( + self, + cls: Type["AiPlatformResourceNoun"], + lro: Optional[operation.Operation] = None, + ): + """Logs create event with LRO. + + Args: + cls (AiPlatformResourceNoune): + AI Platform Resource Noun class that is being created. + lro (operation.Operation): + Optional. Backing LRO for creation. + """ + self._logger.info(f"Creating {cls.__name__}") + + if lro: + self._logger.info( + f"Create {cls.__name__} backing LRO: {lro.operation.name}" + ) + + def log_create_complete( + self, + cls: Type["AiPlatformResourceNoun"], + resource: proto.Message, + variable_name: str, + ): + """Logs create event is complete. + + Will also include code snippet to instantiate resource in SDK. + + Args: + cls (AiPlatformResourceNoun): + AI Platform Resource Noun class that is being created. + resource (proto.Message): + AI Platform Resourc proto.Message + variable_name (str): Name of variable to use for code snippet + + """ + self._logger.info(f"{cls.__name__} created. Resource name: {resource.name}") + self._logger.info(f"To use this {cls.__name__} in another session:") + self._logger.info( + f"{variable_name} = aiplatform.{cls.__name__}('{resource.name}')" + ) + + def log_action_start_against_resource( + self, action: str, noun: str, resource_noun_obj: "AiPlatformResourceNoun" + ): + """Logs intention to start an action against a resource. + + Args: + action (str): Action to complete against the resource ie: "Deploying". Can be empty string. + noun (str): Noun the action acts on against the resource. Can be empty string. + resource_noun_obj (AiPlatformResourceNoun): + Resource noun object the action is acting against. + """ + self._logger.info( + f"{action} {resource_noun_obj.__class__.__name__} {noun}: {resource_noun_obj.resource_name}" + ) + + def log_action_started_against_resource_with_lro( + self, + action: str, + noun: str, + cls: Type["AiPlatformResourceNoun"], + lro: operation.Operation, + ): + """Logs an action started against a resource with lro. + + Args: + action (str): Action started against resource. ie: "Deploy". Can be empty string. + noun (str): Noun the action acts on against the resource. Can be empty string. + cls (AiPlatformResourceNoun): + Resource noun object the action is acting against. + lro (operation.Operation): Backing LRO for action. + """ + self._logger.info( + f"{action} {cls.__name__} {noun} backing LRO: {lro.operation.name}" + ) + + def log_action_completed_against_resource( + self, noun: str, action: str, resource_noun_obj: "AiPlatformResourceNoun" + ): + """Logs action completed against resource. + + Args: + noun (str): Noun the action acts on against the resource. Can be empty string. + action (str): Action started against resource. ie: "Deployed". Can be empty string. + resource_noun_obj (AiPlatformResourceNoun): + Resource noun object the action is acting against + """ + self._logger.info( + f"{resource_noun_obj.__class__.__name__} {noun} {action}. Resource name: {resource_noun_obj.resource_name}" + ) + + def __getattr__(self, attr: str): + """Forward remainder of logging to underlying logger.""" + return getattr(self._logger, attr) + + +_LOGGER = Logger(__name__) + + +class FutureManager(metaclass=abc.ABCMeta): + """Tracks concurrent futures against this object.""" + + def __init__(self): + self.__latest_future_lock = threading.Lock() + + # Always points to the latest future. All submitted futures will always + # form a dependency on the latest future. + self.__latest_future = None + + # Caches Exception of any executed future. Once one exception occurs + # all additional futures should fail and any additional invocations will block. + self._exception = None + + def _raise_future_exception(self): + """Raises exception if one of the object's futures has raised.""" + with self.__latest_future_lock: + if self._exception: + raise self._exception + + def _complete_future(self, future: futures.Future): + """Checks for exception of future and removes the pointer if it's still latest. + + Args: + future (futures.Future): Required. A future to complete. + """ + + with self.__latest_future_lock: + try: + future.result() # raises + except Exception as e: + self._exception = e + + if self.__latest_future is future: + self.__latest_future = None + + def _are_futures_done(self) -> bool: + """Helper method to check to all futures are complete. + + Returns: + True if no latest future. + """ + with self.__latest_future_lock: + return self.__latest_future is None + + def wait(self): + """Helper method to that blocks until all futures are complete.""" + future = self.__latest_future + if future: + futures.wait([future], return_when=futures.FIRST_EXCEPTION) + + self._raise_future_exception() + + @property + def _latest_future(self) -> Optional[futures.Future]: + """Get the latest future if it exists""" + with self.__latest_future_lock: + return self.__latest_future + + @_latest_future.setter + def _latest_future(self, future: Optional[futures.Future]): + """Optionally set the latest future and add a complete_future callback.""" + with self.__latest_future_lock: + self.__latest_future = future + if future: + future.add_done_callback(self._complete_future) + + def _submit( + self, + method: Callable[..., Any], + args: Sequence[Any], + kwargs: Dict[str, Any], + additional_dependencies: Optional[Sequence[futures.Future]] = None, + callbacks: Optional[Sequence[Callable[[futures.Future], Any]]] = None, + internal_callbacks: Iterable[Callable[[Any], Any]] = None, + ) -> futures.Future: + """Submit a method as a future against this object. + + Args: + method (Callable): Required. The method to submit. + args (Sequence): Required. The arguments to call the method with. + kwargs (dict): Required. The keyword arguments to call the method with. + additional_dependencies (Optional[Sequence[futures.Future]]): + Optional. Additional dependent futures to wait on before executing + method. Note: No validation is done on the dependencies. + callbacks (Optional[Sequence[Callable[[futures.Future], Any]]]): + Optional. Additional Future callbacks to execute once this created + Future is complete. + + Returns: + future (Future): Future of the submitted method call. + """ + + def wait_for_dependencies_and_invoke( + deps: Sequence[futures.Future], + method: Callable[..., Any], + args: Sequence[Any], + kwargs: Dict[str, Any], + internal_callbacks: Iterable[Callable[[Any], Any]], + ) -> Any: + """Wrapper method to wait on any dependencies before submitting method. + + Args: + deps (Sequence[futures.Future]): + Required. Dependent futures to wait on before executing method. + Note: No validation is done on the dependencies. + method (Callable): Required. The method to submit. + args (Sequence[Any]): Required. The arguments to call the method with. + kwargs (Dict[str, Any]): + Required. The keyword arguments to call the method with. + internal_callbacks: (Callable[[Any], Any]): + Callbacks that take the result of method. + + """ + + for future in set(deps): + future.result() + + result = method(*args, **kwargs) + + # call callbacks from within future + if internal_callbacks: + for callback in internal_callbacks: + callback(result) + + return result + + # Retrieves any dependencies from arguments. + deps = [ + arg._latest_future + for arg in list(args) + list(kwargs.values()) + if isinstance(arg, FutureManager) + ] + + # Retrieves exceptions and raises + # if any upstream dependency has an exception + exceptions = [ + arg._exception + for arg in list(args) + list(kwargs.values()) + if isinstance(arg, FutureManager) and arg._exception + ] + + if exceptions: + raise exceptions[0] + + # filter out objects that do not have pending tasks + deps = [dep for dep in deps if dep] + + if additional_dependencies: + deps.extend(additional_dependencies) + + with self.__latest_future_lock: + + # form a dependency on the latest future of this object + if self.__latest_future: + deps.append(self.__latest_future) + + self.__latest_future = initializer.global_pool.submit( + wait_for_dependencies_and_invoke, + deps=deps, + method=method, + args=args, + kwargs=kwargs, + internal_callbacks=internal_callbacks, + ) + + future = self.__latest_future + + # Clean up callback captures exception as well as removes future. + # May execute immediately and take lock. + + future.add_done_callback(self._complete_future) + + if callbacks: + for c in callbacks: + future.add_done_callback(c) + + return future + + @classmethod + @abc.abstractmethod + def _empty_constructor(cls) -> "FutureManager": + """Should construct object with all non FutureManager attributes as None""" + pass + + @abc.abstractmethod + def _sync_object_with_future_result(self, result: "FutureManager"): + """Should sync the object from _empty_constructor with result of future.""" + + def __repr__(self) -> str: + if self._exception: + return f"{object.__repr__(self)} failed with {str(self._exception)}" + + if self.__latest_future: + return f"{object.__repr__(self)} is waiting for upstream dependencies to complete." + + return object.__repr__(self) + + +class AiPlatformResourceNoun(metaclass=abc.ABCMeta): + """Base class the AI Platform resource nouns. + + Subclasses require two class attributes: + + client_class: The client to instantiate to interact with this resource noun. + _is_client_prediction_client: Flag to indicate if the client requires a prediction endpoint. + + Subclass is required to populate private attribute _gca_resource which is the + service representation of the resource noun. + """ + + @property + @classmethod + @abc.abstractmethod + def client_class(cls) -> Type[utils.AiPlatformServiceClientWithOverride]: + """Client class required to interact with resource with optional overrides.""" + pass + + @property + @classmethod + @abc.abstractmethod + def _is_client_prediction_client(cls) -> bool: + """Flag to indicate whether to use prediction endpoint with client.""" + pass + + @property + @abc.abstractmethod + def _getter_method(cls) -> str: + """Name of getter method of client class for retrieving the resource.""" + pass + + @property + @abc.abstractmethod + def _delete_method(cls) -> str: + """Name of delete method of client class for deleting the resource.""" + pass + + @property + @abc.abstractmethod + def _resource_noun(cls) -> str: + """Resource noun""" + pass + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, + ): + """Initializes class with project, location, and api_client. + + Args: + project(str): Project of the resource noun. + location(str): The location of the resource noun. + credentials(google.auth.crendentials.Crendentials): Optional custom + credentials to use when accessing interacting with resource noun. + resource_name(str): A fully-qualified resource name or ID. + """ + + if resource_name: + project, location = self._get_and_validate_project_location( + resource_name=resource_name, project=project, location=location + ) + + self.project = project or initializer.global_config.project + self.location = location or initializer.global_config.location + self.credentials = credentials or initializer.global_config.credentials + + self.api_client = self._instantiate_client(self.location, self.credentials) + + @classmethod + def _instantiate_client( + cls, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> utils.AiPlatformServiceClientWithOverride: + """Helper method to instantiate service client for resource noun. + + Args: + location (str): The location of the resource noun. + credentials (google.auth.credentials.Credentials): + Optional custom credentials to use when accessing interacting with + resource noun. + Returns: + client (utils.AiPlatformServiceClientWithOverride): + Initialized service client for this service noun with optional overrides. + """ + return initializer.global_config.create_client( + client_class=cls.client_class, + credentials=credentials, + location_override=location, + prediction_client=cls._is_client_prediction_client, + ) + + def _get_and_validate_project_location( + self, + resource_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + ) -> Tuple: + + """Validate the project and location for the resource. + + Args: + resource_name(str): Required. A fully-qualified resource name or ID. + project(str): Project of the resource noun. + location(str): The location of the resource noun. + + Raises: + RuntimeError if location is different from resource location + """ + + if not project and not location: + return project, location + + fields = utils.extract_fields_from_resource_name( + resource_name, self._resource_noun + ) + if not fields: + return project, location + + if location and fields.location != location: + raise RuntimeError( + f"location {location} is provided, but different from " + f"the resource location {fields.location}" + ) + + return fields.project, fields.location + + def _get_gca_resource(self, resource_name: str) -> proto.Message: + """Returns GAPIC service representation of client class resource.""" + """ + Args: + resource_name (str): + Required. A fully-qualified resource name or ID. + """ + + resource_name = utils.full_resource_name( + resource_name=resource_name, + resource_noun=self._resource_noun, + project=self.project, + location=self.location, + ) + + return getattr(self.api_client, self._getter_method)(name=resource_name) + + def _sync_gca_resource(self): + """Sync GAPIC service representation of client class resource.""" + + self._gca_resource = self._get_gca_resource(resource_name=self.resource_name) + + @property + def name(self) -> str: + """Name of this resource.""" + return self._gca_resource.name.split("/")[-1] + + @property + def resource_name(self) -> str: + """Full qualified resource name.""" + return self._gca_resource.name + + @property + def display_name(self) -> str: + """Display name of this resource.""" + return self._gca_resource.display_name + + @property + def create_time(self) -> datetime.datetime: + """Time this resource was created.""" + return self._gca_resource.create_time + + @property + def update_time(self) -> datetime.datetime: + """Time this resource was last updated.""" + self._sync_gca_resource() + return self._gca_resource.update_time + + def __repr__(self) -> str: + return f"{object.__repr__(self)} \nresource name: {self.resource_name}" + + +def optional_sync( + construct_object_on_arg: Optional[str] = None, + return_input_arg: Optional[str] = None, + bind_future_to_self: bool = True, +): + """Decorator for AiPlatformResourceNounWithFutureManager with optional sync support. + + Methods with this decorator should include a "sync" argument that defaults to + True. If called with sync=False this decorator will launch the method as a + concurrent Future in a separate Thread. + + Note that this is only robust enough to support our current end to end patterns + and may not be suitable for new patterns. + + Args: + construct_object_on_arg (str): + Optional. If provided, will only construct output object if arg is present. + Example: If custom training does not produce a model. + return_input_arg (str): + Optional. If provided will return passed in argument instead of + constructing. + Example: Model.deploy(Endpoint) returns the passed in Endpoint + bind_future_to_self (bool): + Whether to add this future to the calling object. + Example: Model.deploy(Endpoint) would be set to False because we only + want the deployment Future to be associated with Endpoint. + """ + + def optional_run_in_thread(method: Callable[..., Any]): + """Optionally run this method concurrently in separate Thread. + + Args: + method (Callable[..., Any]): Method to optionally run in separate Thread. + """ + + @functools.wraps(method) + def wrapper(*args, **kwargs): + """Wraps method.""" + sync = kwargs.pop("sync", True) + bound_args = inspect.signature(method).bind(*args, **kwargs) + self = bound_args.arguments.get("self") + calling_object_latest_future = None + + # check to see if this object has any exceptions + if self: + calling_object_latest_future = self._latest_future + self._raise_future_exception() + + # if sync then wait for any Futures to complete and execute + if sync: + if self: + self.wait() + return method(*args, **kwargs) + + # callbacks to call within the Future (in same Thread) + internal_callbacks = [] + # callbacks to add to the Future (may or may not be in same Thread) + callbacks = [] + # additional Future dependencies to capture + dependencies = [] + + # all methods should have type signatures + return_type = get_annotation_class( + inspect.getfullargspec(method).annotations["return"] + ) + + # is a classmethod that creates the object and returns it + if args and inspect.isclass(args[0]): + # assumes classmethod is our resource noun + returned_object = args[0]._empty_constructor() + self = returned_object + + else: # instance method + + # object produced by the method + returned_object = bound_args.arguments.get(return_input_arg) + + # if we're returning an input object + if returned_object and returned_object is not self: + + # make sure the input object doesn't have any exceptions + # from previous futures + returned_object._raise_future_exception() + + # if the future will be associated with both the returned object + # and calling object then we need to add additional callback + # to remove the future from the returned object + + # if we need to construct a new empty returned object + should_construct = not returned_object and bound_args.arguments.get( + construct_object_on_arg, not construct_object_on_arg + ) + + if should_construct: + if return_type is not None: + returned_object = return_type._empty_constructor() + + # if the future will be associated with both the returned object + # and calling object then we need to add additional callback + # to remove the future from the returned object + if returned_object and bind_future_to_self: + callbacks.append(returned_object._complete_future) + + if returned_object: + # sync objects after future completes + internal_callbacks.append( + returned_object._sync_object_with_future_result + ) + + # If the future is not associated with the calling object + # then the return object future needs to form a dependency on the + # the latest future in the calling object. + if not bind_future_to_self: + if calling_object_latest_future: + dependencies.append(calling_object_latest_future) + self = returned_object + + future = self._submit( + method=method, + callbacks=callbacks, + internal_callbacks=internal_callbacks, + additional_dependencies=dependencies, + args=[], + kwargs=bound_args.arguments, + ) + + # if the calling object is the one that submitted then add it's future + # to the returned object + if returned_object and returned_object is not self: + returned_object._latest_future = future + + return returned_object + + return wrapper + + return optional_run_in_thread + + +class AiPlatformResourceNounWithFutureManager(AiPlatformResourceNoun, FutureManager): + """Allows optional asynchronous calls to this AI Platform Resource Nouns.""" + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, + ): + """Initializes class with project, location, and api_client. + + Args: + project (str): Optional. Project of the resource noun. + location (str): Optional. The location of the resource noun. + credentials(google.auth.crendentials.Crendentials): + Optional. custom credentials to use when accessing interacting with + resource noun. + resource_name(str): A fully-qualified resource name or ID. + """ + AiPlatformResourceNoun.__init__( + self, + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) + FutureManager.__init__(self) + + @classmethod + def _empty_constructor( + cls, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, + ) -> "AiPlatformResourceNounWithFutureManager": + """Initializes with all attributes set to None. + + The attributes should be populated after a future is complete. This allows + scheduling of additional API calls before the resource is created. + + Args: + project (str): Optional. Project of the resource noun. + location (str): Optional. The location of the resource noun. + credentials(google.auth.crendentials.Crendentials): + Optional. custom credentials to use when accessing interacting with + resource noun. + resource_name(str): A fully-qualified resource name or ID. + Returns: + An instance of this class with attributes set to None. + """ + self = cls.__new__(cls) + AiPlatformResourceNoun.__init__( + self, + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) + FutureManager.__init__(self) + self._gca_resource = None + return self + + def _sync_object_with_future_result( + self, result: "AiPlatformResourceNounWithFutureManager" + ): + """Populates attributes from a Future result to this object. + + Args: + result: AiPlatformResourceNounWithFutureManager + Required. Result of future with same type as this object. + """ + sync_attributes = [ + "project", + "location", + "api_client", + "_gca_resource", + "credentials", + ] + optional_sync_attributes = ["_prediction_client"] + + for attribute in sync_attributes: + setattr(self, attribute, getattr(result, attribute)) + + for attribute in optional_sync_attributes: + value = getattr(result, attribute, None) + if value: + setattr(self, attribute, value) + + def _construct_sdk_resource_from_gapic( + self, + gapic_resource: proto.Message, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> AiPlatformResourceNoun: + """Given a GAPIC resource object, return the SDK representation. + + Args: + gapic_resource (proto.Message): + A GAPIC representation of an AI Platform resource, usually + retrieved by a get_* or in a list_* API call. + project (str): + Optional. Project to construct SDK object from. If not set, + project set in aiplatform.init will be used. + location (str): + Optional. Location to construct SDK object from. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to construct SDK object. + Overrides credentials set in aiplatform.init. + + Returns: + AiPlatformResourceNoun: + An initialized SDK object that represents GAPIC type. + """ + sdk_resource = self._empty_constructor( + project=project, location=location, credentials=credentials + ) + sdk_resource._gca_resource = gapic_resource + return sdk_resource + + # TODO(b/144545165): Improve documentation for list filtering once available + # TODO(b/184910159): Expose `page_size` field in list method + @classmethod + def _list( + cls, + cls_filter: Callable[[proto.Message], bool] = lambda _: True, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """Private method to list all instances of this AI Platform Resource, + takes a `cls_filter` arg to filter to a particular SDK resource subclass. + + Args: + cls_filter (Callable[[proto.Message], bool]): + A function that takes one argument, a GAPIC resource, and returns + a bool. If the function returns False, that resource will be + excluded from the returned list. Example usage: + cls_filter = lambda obj: obj.metadata in cls.valid_metadatas + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ + self = cls._empty_constructor( + project=project, location=location, credentials=credentials + ) + + # Fetch credentials once and re-use for all `_empty_constructor()` calls + creds = initializer.global_config.credentials + + resource_list_method = getattr(self.api_client, self._list_method) + + list_request = { + "parent": initializer.global_config.common_location_path( + project=project, location=location + ), + "filter": filter, + } + + if order_by: + list_request["order_by"] = order_by + + resource_list = resource_list_method(request=list_request) or [] + + return [ + self._construct_sdk_resource_from_gapic( + gapic_resource, project=project, location=location, credentials=creds + ) + for gapic_resource in resource_list + if cls_filter(gapic_resource) + ] + + @classmethod + def _list_with_local_order( + cls, + cls_filter: Callable[[proto.Message], bool] = lambda _: True, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """Private method to list all instances of this AI Platform Resource, + takes a `cls_filter` arg to filter to a particular SDK resource subclass. + Provides client-side sorting when a list API doesn't support `order_by`. + + Args: + cls_filter (Callable[[proto.Message], bool]): + A function that takes one argument, a GAPIC resource, and returns + a bool. If the function returns False, that resource will be + excluded from the returned list. Example usage: + cls_filter = lambda obj: obj.metadata in cls.valid_metadatas + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ + + li = cls._list( + cls_filter=cls_filter, + filter=filter, + order_by=None, # This method will handle the ordering locally + project=project, + location=location, + credentials=credentials, + ) + + if order_by: + desc = "desc" in order_by + order_by = order_by.replace("desc", "") + order_by = order_by.split(",") + + li.sort( + key=lambda x: tuple(getattr(x, field.strip()) for field in order_by), + reverse=desc, + ) + + return li + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """List all instances of this AI Platform Resource. + + Example Usage: + + aiplatform.BatchPredictionJobs.list( + filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"', + ) + + aiplatform.Model.list(order_by="create_time desc, display_name") + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ + + return cls._list( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + @optional_sync() + def delete(self, sync: bool = True) -> None: + """Deletes this AI Platform resource. WARNING: This deletion is permament. + + Args: + sync (bool): + Whether to execute this deletion synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + _LOGGER.log_action_start_against_resource("Deleting", "", self) + lro = getattr(self.api_client, self._delete_method)(name=self.resource_name) + _LOGGER.log_action_started_against_resource_with_lro( + "Delete", "", self.__class__, lro + ) + lro.result() + _LOGGER.log_action_completed_against_resource("deleted.", "", self) + + def __repr__(self) -> str: + if self._gca_resource: + return AiPlatformResourceNoun.__repr__(self) + + return FutureManager.__repr__(self) + + +def get_annotation_class(annotation: type) -> type: + """Helper method to retrieve type annotation. + + Args: + annotation (type): Type hint + """ + # typing.Optional + if getattr(annotation, "__origin__", None) is Union: + return annotation.__args__[0] + else: + return annotation diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py new file mode 100644 index 0000000000..36d805c6cb --- /dev/null +++ b/google/cloud/aiplatform/compat/__init__.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.compat import services +from google.cloud.aiplatform.compat import types + +V1BETA1 = "v1beta1" +V1 = "v1" + +DEFAULT_VERSION = V1 + +if DEFAULT_VERSION == V1BETA1: + + services.dataset_service_client = services.dataset_service_client_v1beta1 + services.endpoint_service_client = services.endpoint_service_client_v1beta1 + services.job_service_client = services.job_service_client_v1beta1 + services.model_service_client = services.model_service_client_v1beta1 + services.pipeline_service_client = services.pipeline_service_client_v1beta1 + services.prediction_service_client = services.prediction_service_client_v1beta1 + services.specialist_pool_service_client = ( + services.specialist_pool_service_client_v1beta1 + ) + + types.accelerator_type = types.accelerator_type_v1beta1 + types.annotation = types.annotation_v1beta1 + types.annotation_spec = types.annotation_spec_v1beta1 + types.batch_prediction_job = types.batch_prediction_job_v1beta1 + types.completion_stats = types.completion_stats_v1beta1 + types.custom_job = types.custom_job_v1beta1 + types.data_item = types.data_item_v1beta1 + types.data_labeling_job = types.data_labeling_job_v1beta1 + types.dataset = types.dataset_v1beta1 + types.dataset_service = types.dataset_service_v1beta1 + types.deployed_model_ref = types.deployed_model_ref_v1beta1 + types.encryption_spec = types.encryption_spec_v1beta1 + types.endpoint = types.endpoint_v1beta1 + types.endpoint_service = types.endpoint_service_v1beta1 + types.env_var = types.env_var_v1beta1 + types.explanation = types.explanation_v1beta1 + types.explanation_metadata = types.explanation_metadata_v1beta1 + types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1 + types.io = types.io_v1beta1 + types.job_service = types.job_service_v1beta1 + types.job_state = types.job_state_v1beta1 + types.machine_resources = types.machine_resources_v1beta1 + types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1beta1 + types.model = types.model_v1beta1 + types.model_evaluation = types.model_evaluation_v1beta1 + types.model_evaluation_slice = types.model_evaluation_slice_v1beta1 + types.model_service = types.model_service_v1beta1 + types.operation = types.operation_v1beta1 + types.pipeline_service = types.pipeline_service_v1beta1 + types.pipeline_state = types.pipeline_state_v1beta1 + types.prediction_service = types.prediction_service_v1beta1 + types.specialist_pool = types.specialist_pool_v1beta1 + types.specialist_pool_service = types.specialist_pool_service_v1beta1 + types.training_pipeline = types.training_pipeline_v1beta1 + +if DEFAULT_VERSION == V1: + + services.dataset_service_client = services.dataset_service_client_v1 + services.endpoint_service_client = services.endpoint_service_client_v1 + services.job_service_client = services.job_service_client_v1 + services.model_service_client = services.model_service_client_v1 + services.pipeline_service_client = services.pipeline_service_client_v1 + services.prediction_service_client = services.prediction_service_client_v1 + services.specialist_pool_service_client = services.specialist_pool_service_client_v1 + + types.accelerator_type = types.accelerator_type_v1 + types.annotation = types.annotation_v1 + types.annotation_spec = types.annotation_spec_v1 + types.batch_prediction_job = types.batch_prediction_job_v1 + types.completion_stats = types.completion_stats_v1 + types.custom_job = types.custom_job_v1 + types.data_item = types.data_item_v1 + types.data_labeling_job = types.data_labeling_job_v1 + types.dataset = types.dataset_v1 + types.dataset_service = types.dataset_service_v1 + types.deployed_model_ref = types.deployed_model_ref_v1 + types.encryption_spec = types.encryption_spec_v1 + types.endpoint = types.endpoint_v1 + types.endpoint_service = types.endpoint_service_v1 + types.env_var = types.env_var_v1 + types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1 + types.io = types.io_v1 + types.job_service = types.job_service_v1 + types.job_state = types.job_state_v1 + types.machine_resources = types.machine_resources_v1 + types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1 + types.model = types.model_v1 + types.model_evaluation = types.model_evaluation_v1 + types.model_evaluation_slice = types.model_evaluation_slice_v1 + types.model_service = types.model_service_v1 + types.operation = types.operation_v1 + types.pipeline_service = types.pipeline_service_v1 + types.pipeline_state = types.pipeline_state_v1 + types.prediction_service = types.prediction_service_v1 + types.specialist_pool = types.specialist_pool_v1 + types.specialist_pool_service = types.specialist_pool_service_v1 + types.training_pipeline = types.training_pipeline_v1 + +__all__ = ( + DEFAULT_VERSION, + V1BETA1, + V1, + services, + types, +) diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py new file mode 100644 index 0000000000..0888c27fbb --- /dev/null +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + client as dataset_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + client as pipeline_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + client as prediction_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + client as specialist_pool_service_client_v1beta1, +) + +from google.cloud.aiplatform_v1.services.dataset_service import ( + client as dataset_service_client_v1, +) +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client_v1, +) +from google.cloud.aiplatform_v1.services.job_service import ( + client as job_service_client_v1, +) +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client_v1, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client_v1, +) +from google.cloud.aiplatform_v1.services.prediction_service import ( + client as prediction_service_client_v1, +) +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + client as specialist_pool_service_client_v1, +) + +__all__ = ( + # v1 + dataset_service_client_v1, + endpoint_service_client_v1, + job_service_client_v1, + model_service_client_v1, + pipeline_service_client_v1, + prediction_service_client_v1, + specialist_pool_service_client_v1, + # v1beta1 + dataset_service_client_v1beta1, + endpoint_service_client_v1beta1, + job_service_client_v1beta1, + model_service_client_v1beta1, + pipeline_service_client_v1beta1, + prediction_service_client_v1beta1, + specialist_pool_service_client_v1beta1, +) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py new file mode 100644 index 0000000000..d03e0d2f3a --- /dev/null +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform_v1beta1.types import ( + accelerator_type as accelerator_type_v1beta1, + annotation as annotation_v1beta1, + annotation_spec as annotation_spec_v1beta1, + batch_prediction_job as batch_prediction_job_v1beta1, + completion_stats as completion_stats_v1beta1, + custom_job as custom_job_v1beta1, + data_item as data_item_v1beta1, + data_labeling_job as data_labeling_job_v1beta1, + dataset as dataset_v1beta1, + dataset_service as dataset_service_v1beta1, + deployed_model_ref as deployed_model_ref_v1beta1, + encryption_spec as encryption_spec_v1beta1, + endpoint as endpoint_v1beta1, + endpoint_service as endpoint_service_v1beta1, + env_var as env_var_v1beta1, + explanation as explanation_v1beta1, + explanation_metadata as explanation_metadata_v1beta1, + hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1, + io as io_v1beta1, + job_service as job_service_v1beta1, + job_state as job_state_v1beta1, + machine_resources as machine_resources_v1beta1, + manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1, + model as model_v1beta1, + model_evaluation as model_evaluation_v1beta1, + model_evaluation_slice as model_evaluation_slice_v1beta1, + model_service as model_service_v1beta1, + operation as operation_v1beta1, + pipeline_service as pipeline_service_v1beta1, + pipeline_state as pipeline_state_v1beta1, + prediction_service as prediction_service_v1beta1, + specialist_pool as specialist_pool_v1beta1, + specialist_pool_service as specialist_pool_service_v1beta1, + training_pipeline as training_pipeline_v1beta1, +) +from google.cloud.aiplatform_v1.types import ( + accelerator_type as accelerator_type_v1, + annotation as annotation_v1, + annotation_spec as annotation_spec_v1, + batch_prediction_job as batch_prediction_job_v1, + completion_stats as completion_stats_v1, + custom_job as custom_job_v1, + data_item as data_item_v1, + data_labeling_job as data_labeling_job_v1, + dataset as dataset_v1, + dataset_service as dataset_service_v1, + deployed_model_ref as deployed_model_ref_v1, + encryption_spec as encryption_spec_v1, + endpoint as endpoint_v1, + endpoint_service as endpoint_service_v1, + env_var as env_var_v1, + hyperparameter_tuning_job as hyperparameter_tuning_job_v1, + io as io_v1, + job_service as job_service_v1, + job_state as job_state_v1, + machine_resources as machine_resources_v1, + manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1, + model as model_v1, + model_evaluation as model_evaluation_v1, + model_evaluation_slice as model_evaluation_slice_v1, + model_service as model_service_v1, + operation as operation_v1, + pipeline_service as pipeline_service_v1, + pipeline_state as pipeline_state_v1, + prediction_service as prediction_service_v1, + specialist_pool as specialist_pool_v1, + specialist_pool_service as specialist_pool_service_v1, + training_pipeline as training_pipeline_v1, +) + +__all__ = ( + # v1 + accelerator_type_v1, + annotation_v1, + annotation_spec_v1, + batch_prediction_job_v1, + completion_stats_v1, + custom_job_v1, + data_item_v1, + data_labeling_job_v1, + dataset_v1, + dataset_service_v1, + deployed_model_ref_v1, + encryption_spec_v1, + endpoint_v1, + endpoint_service_v1, + env_var_v1, + hyperparameter_tuning_job_v1, + io_v1, + job_service_v1, + job_state_v1, + machine_resources_v1, + manual_batch_tuning_parameters_v1, + model_v1, + model_evaluation_v1, + model_evaluation_slice_v1, + model_service_v1, + operation_v1, + pipeline_service_v1, + pipeline_state_v1, + prediction_service_v1, + specialist_pool_v1, + specialist_pool_service_v1, + training_pipeline_v1, + # v1beta1 + accelerator_type_v1beta1, + annotation_v1beta1, + annotation_spec_v1beta1, + batch_prediction_job_v1beta1, + completion_stats_v1beta1, + custom_job_v1beta1, + data_item_v1beta1, + data_labeling_job_v1beta1, + dataset_v1beta1, + dataset_service_v1beta1, + deployed_model_ref_v1beta1, + encryption_spec_v1beta1, + endpoint_v1beta1, + endpoint_service_v1beta1, + env_var_v1beta1, + explanation_v1beta1, + explanation_metadata_v1beta1, + hyperparameter_tuning_job_v1beta1, + io_v1beta1, + job_service_v1beta1, + job_state_v1beta1, + machine_resources_v1beta1, + manual_batch_tuning_parameters_v1beta1, + model_v1beta1, + model_evaluation_v1beta1, + model_evaluation_slice_v1beta1, + model_service_v1beta1, + operation_v1beta1, + pipeline_service_v1beta1, + pipeline_state_v1beta1, + prediction_service_v1beta1, + specialist_pool_v1beta1, + specialist_pool_service_v1beta1, + training_pipeline_v1beta1, +) diff --git a/google/cloud/aiplatform/constants.py b/google/cloud/aiplatform/constants.py new file mode 100644 index 0000000000..62c28009c2 --- /dev/null +++ b/google/cloud/aiplatform/constants.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DEFAULT_REGION = "us-central1" +SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1") +API_BASE_PATH = "aiplatform.googleapis.com" + +# Batch Prediction +BATCH_PREDICTION_INPUT_STORAGE_FORMATS = ( + "jsonl", + "csv", + "tf-record", + "tf-record-gzip", + "bigquery", + "file-list", +) +BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery") + +MOBILE_TF_MODEL_TYPES = { + "MOBILE_TF_LOW_LATENCY_1", + "MOBILE_TF_VERSATILE_1", + "MOBILE_TF_HIGH_ACCURACY_1", +} + +# TODO(b/177079208): Use EPCL Enums for validating Model Types +# Defined by gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_* +# Format: "prediction_type": set() of model_type's +# +# NOTE: When adding a new prediction_type's, ensure it fits the pattern +# "automl_image_{prediction_type}_*" used by the YAML schemas on GCS +AUTOML_IMAGE_PREDICTION_MODEL_TYPES = { + "classification": {"CLOUD"} | MOBILE_TF_MODEL_TYPES, + "object_detection": {"CLOUD_HIGH_ACCURACY_1", "CLOUD_LOW_LATENCY_1"} + | MOBILE_TF_MODEL_TYPES, +} + +AUTOML_VIDEO_PREDICTION_MODEL_TYPES = { + "classification": {"CLOUD"} | {"MOBILE_VERSATILE_1"}, + "action_recognition": {"CLOUD"} | {"MOBILE_VERSATILE_1"}, + "object_tracking": {"CLOUD"} + | { + "MOBILE_VERSATILE_1", + "MOBILE_CORAL_VERSATILE_1", + "MOBILE_CORAL_LOW_LATENCY_1", + "MOBILE_JETSON_VERSATILE_1", + "MOBILE_JETSON_LOW_LATENCY_1", + }, +} diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py new file mode 100644 index 0000000000..57e2bad45d --- /dev/null +++ b/google/cloud/aiplatform/datasets/__init__.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.datasets.dataset import _Dataset +from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset +from google.cloud.aiplatform.datasets.image_dataset import ImageDataset +from google.cloud.aiplatform.datasets.text_dataset import TextDataset +from google.cloud.aiplatform.datasets.video_dataset import VideoDataset + + +__all__ = ( + "_Dataset", + "TabularDataset", + "ImageDataset", + "TextDataset", + "VideoDataset", +) diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py new file mode 100644 index 0000000000..eefd1b04fd --- /dev/null +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +from typing import Optional, Dict, Sequence, Union +from google.cloud.aiplatform import schema + +from google.cloud.aiplatform.compat.types import ( + io as gca_io, + dataset as gca_dataset, +) + + +class Datasource(abc.ABC): + """An abstract class that sets dataset_metadata""" + + @property + @abc.abstractmethod + def dataset_metadata(self): + """Dataset Metadata.""" + pass + + +class DatasourceImportable(abc.ABC): + """An abstract class that sets import_data_config""" + + @property + @abc.abstractmethod + def import_data_config(self): + """Import Data Config.""" + pass + + +class TabularDatasource(Datasource): + """Datasource for creating a tabular dataset for AI Platform""" + + def __init__( + self, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + ): + """Creates a tabular datasource + + Args: + gcs_source (Union[str, Sequence[str]]): + Cloud Storage URI of one or more files. Only CSV files are supported. + The first line of the CSV file is used as the header. + If there are multiple files, the header is the first line of + the lexicographically first file, the other files must either + contain the exact same header or omit the header. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + The URI of a BigQuery table. + example: + "bq://project.dataset.table_name" + + Raises: + ValueError if source configuration is not valid. + """ + + dataset_metadata = None + + if gcs_source and isinstance(gcs_source, str): + gcs_source = [gcs_source] + + if gcs_source and bq_source: + raise ValueError("Only one of gcs_source or bq_source can be set.") + + if not any([gcs_source, bq_source]): + raise ValueError("One of gcs_source or bq_source must be set.") + + if gcs_source: + dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}} + elif bq_source: + dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}} + + self._dataset_metadata = dataset_metadata + + @property + def dataset_metadata(self) -> Optional[Dict]: + """Dataset Metadata.""" + return self._dataset_metadata + + +class NonTabularDatasource(Datasource): + """Datasource for creating an empty non-tabular dataset for AI Platform""" + + @property + def dataset_metadata(self) -> Optional[Dict]: + return None + + +class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable): + """Datasource for creating a non-tabular dataset for AI Platform and importing data to the dataset""" + + def __init__( + self, + gcs_source: Union[str, Sequence[str]], + import_schema_uri: str, + data_item_labels: Optional[Dict] = None, + ): + """Creates a non-tabular datasource + + Args: + gcs_source (Union[str, Sequence[str]]): + Required. The Google Cloud Storage location for the input content. + Google Cloud Storage URI(-s) to the input file(s). May contain + wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + """ + super().__init__() + self._gcs_source = [gcs_source] if isinstance(gcs_source, str) else gcs_source + self._import_schema_uri = import_schema_uri + self._data_item_labels = data_item_labels + + @property + def import_data_config(self) -> gca_dataset.ImportDataConfig: + """Import Data Config.""" + return gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=self._gcs_source), + import_schema_uri=self._import_schema_uri, + data_item_labels=self._data_item_labels, + ) + + +def create_datasource( + metadata_schema_uri: str, + import_schema_uri: Optional[str] = None, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + data_item_labels: Optional[Dict] = None, +) -> Datasource: + """Creates a datasource + Args: + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + gcs_source (Union[str, Sequence[str]]): + The Google Cloud Storage location for the input content. + Google Cloud Storage URI(-s) to the input file(s). May contain + wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + + Returns: + datasource (Datasource) + + Raises: + ValueError when below scenarios happen + - import_schema_uri is identified for creating TabularDatasource + - either import_schema_uri or gcs_source is missing for creating NonTabularDatasourceImportable + """ + + if metadata_schema_uri == schema.dataset.metadata.tabular: + if import_schema_uri: + raise ValueError("tabular dataset does not support data import.") + return TabularDatasource(gcs_source, bq_source) + + if not import_schema_uri and not gcs_source: + return NonTabularDatasource() + elif import_schema_uri and gcs_source: + return NonTabularDatasourceImportable( + gcs_source, import_schema_uri, data_item_labels + ) + else: + raise ValueError( + "nontabular dataset requires both import_schema_uri and gcs_source for data import." + ) diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py new file mode 100644 index 0000000000..25078ab2c5 --- /dev/null +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -0,0 +1,577 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Dict, Tuple, Union, List + +from google.api_core import operation +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.services import dataset_service_client +from google.cloud.aiplatform.compat.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + io as gca_io, +) +from google.cloud.aiplatform.datasets import _datasources + +_LOGGER = base.Logger(__name__) + + +class _Dataset(base.AiPlatformResourceNounWithFutureManager): + """Managed dataset resource for AI Platform""" + + client_class = utils.DatasetClientWithOverride + _is_client_prediction_client = False + _resource_noun = "datasets" + _getter_method = "get_dataset" + _list_method = "list_datasets" + _delete_method = "delete_dataset" + + _supported_metadata_schema_uris: Tuple[str] = () + + def __init__( + self, + dataset_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed dataset given a dataset name or ID. + + Args: + dataset_name (str): + Required. A fully-qualified dataset resource name or dataset ID. + Example: "projects/123/locations/us-central1/datasets/456" or + "456" when project and location are initialized or passed. + project (str): + Optional project to retrieve dataset from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve dataset from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=dataset_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=dataset_name) + self._validate_metadata_schema_uri() + + @property + def metadata_schema_uri(self) -> str: + """The metadata schema uri of this dataset resource.""" + return self._gca_resource.metadata_schema_uri + + def _validate_metadata_schema_uri(self) -> None: + """Validate the metadata_schema_uri of retrieved dataset resource. + + Raises: + ValueError if the dataset type of the retrieved dataset resource is + not supported by the class. + """ + if self._supported_metadata_schema_uris and ( + self.metadata_schema_uri not in self._supported_metadata_schema_uris + ): + raise ValueError( + f"{self.__class__.__name__} class can not be used to retrieve " + f"dataset resource {self.resource_name}, check the dataset type" + ) + + @classmethod + def create( + cls, + display_name: str, + metadata_schema_uri: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "_Dataset": + """Creates a new dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + dataset (Dataset): + Instantiated representation of the managed dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + bq_source=bq_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + @classmethod + @base.optional_sync() + def _create_and_import( + cls, + api_client: dataset_service_client.DatasetServiceClient, + parent: str, + display_name: str, + metadata_schema_uri: str, + datasource: _datasources.Datasource, + project: str, + location: str, + credentials: Optional[auth_credentials.Credentials], + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, + sync: bool = True, + ) -> "_Dataset": + """Creates a new dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + api_client (dataset_service_client.DatasetServiceClient): + An instance of DatasetServiceClient with the correct api_endpoint + already set based on user's preferences. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + Example: "projects/my-prj/locations/us-central1" + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + datasource (_datasources.Datasource): + Required. Datasource for creating a dataset for AI Platform. + project (str): + Required. Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Required. Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]): + Optional. The Cloud KMS customer managed encryption key used to protect the dataset. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + dataset (Dataset): + Instantiated representation of the managed dataset resource. + """ + + create_dataset_lro = cls._create( + api_client=api_client, + parent=parent, + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + request_metadata=request_metadata, + encryption_spec=encryption_spec, + ) + + _LOGGER.log_create_with_lro(cls, create_dataset_lro) + + created_dataset = create_dataset_lro.result() + + _LOGGER.log_create_complete(cls, created_dataset, "ds") + + dataset_obj = cls( + dataset_name=created_dataset.name, + project=project, + location=location, + credentials=credentials, + ) + + # Import if import datasource is DatasourceImportable + if isinstance(datasource, _datasources.DatasourceImportable): + dataset_obj._import_and_wait(datasource) + + return dataset_obj + + def _import_and_wait(self, datasource): + _LOGGER.log_action_start_against_resource( + "Importing", "data", self, + ) + + import_lro = self._import(datasource=datasource) + + _LOGGER.log_action_started_against_resource_with_lro( + "Import", "data", self.__class__, import_lro + ) + + import_lro.result() + + _LOGGER.log_action_completed_against_resource("data", "imported", self) + + @classmethod + def _create( + cls, + api_client: dataset_service_client.DatasetServiceClient, + parent: str, + display_name: str, + metadata_schema_uri: str, + datasource: _datasources.Datasource, + request_metadata: Sequence[Tuple[str, str]] = (), + encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, + ) -> operation.Operation: + """Creates a new managed dataset by directly calling API client. + + Args: + api_client (dataset_service_client.DatasetServiceClient): + An instance of DatasetServiceClient with the correct api_endpoint + already set based on user's preferences. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + Example: "projects/my-prj/locations/us-central1" + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + datasource (_datasources.Datasource): + Required. Datasource for creating a dataset for AI Platform. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the create_dataset + request as metadata. Usually to specify special dataset config. + encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]): + Optional. The Cloud KMS customer managed encryption key used to protect the dataset. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + Returns: + operation (Operation): + An object representing a long-running operation. + """ + + gapic_dataset = gca_dataset.Dataset( + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + metadata=datasource.dataset_metadata, + encryption_spec=encryption_spec, + ) + + return api_client.create_dataset( + parent=parent, dataset=gapic_dataset, metadata=request_metadata + ) + + def _import( + self, datasource: _datasources.DatasourceImportable, + ) -> operation.Operation: + """Imports data into managed dataset by directly calling API client. + + Args: + datasource (_datasources.DatasourceImportable): + Required. Datasource for importing data to an existing dataset for AI Platform. + + Returns: + operation (Operation): + An object representing a long-running operation. + """ + return self.api_client.import_data( + name=self.resource_name, import_configs=[datasource.import_data_config] + ) + + @base.optional_sync(return_input_arg="self") + def import_data( + self, + gcs_source: Union[str, Sequence[str]], + import_schema_uri: str, + data_item_labels: Optional[Dict] = None, + sync: bool = True, + ) -> "_Dataset": + """Upload data to existing managed dataset. + + Args: + gcs_source (Union[str, Sequence[str]]): + Required. Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + dataset (Dataset): + Instantiated representation of the managed dataset resource. + """ + datasource = _datasources.create_datasource( + metadata_schema_uri=self.metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + self._import_and_wait(datasource=datasource) + return self + + # TODO(b/174751568) add optional sync support + def export_data(self, output_dir: str) -> Sequence[str]: + """Exports data to output dir to GCS. + + Args: + output_dir (str): + Required. The Google Cloud Storage location where the output is to + be written to. In the given directory a new directory will be + created with name: + ``export-data--`` + where timestamp is in YYYYMMDDHHMMSS format. All export + output will be written into that directory. Inside that + directory, annotations with the same schema will be grouped + into sub directories which are named with the corresponding + annotations' schema title. Inside these sub directories, a + schema.yaml will be created to describe the output format. + + If the uri doesn't end with '/', a '/' will be automatically + appended. The directory is created if it doesn't exist. + + Returns: + exported_files (Sequence[str]): + All of the files that are exported in this export operation. + """ + self.wait() + + # TODO(b/171311614): Add support for BiqQuery export path + export_data_config = gca_dataset.ExportDataConfig( + gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir) + ) + + _LOGGER.log_action_start_against_resource("Exporting", "data", self) + + export_lro = self.api_client.export_data( + name=self.resource_name, export_config=export_data_config + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Export", "data", self.__class__, export_lro + ) + + export_data_response = export_lro.result() + + _LOGGER.log_action_completed_against_resource("data", "export", self) + + return export_data_response.exported_files + + def update(self): + raise NotImplementedError("Update dataset has not been implemented yet") + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[base.AiPlatformResourceNoun]: + """List all instances of this Dataset resource. + + Example Usage: + + aiplatform.TabularDataset.list( + filter='labels.my_key="my_value"', + order_by='display_name' + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[base.AiPlatformResourceNoun] - A list of Dataset resource objects + """ + + dataset_subclass_filter = ( + lambda gapic_obj: gapic_obj.metadata_schema_uri + in cls._supported_metadata_schema_uris + ) + + return cls._list_with_local_order( + cls_filter=dataset_subclass_filter, + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py new file mode 100644 index 0000000000..32db96bea1 --- /dev/null +++ b/google/cloud/aiplatform/datasets/image_dataset.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Dict, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class ImageDataset(datasets._Dataset): + """Managed image dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.image, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "ImageDataset": + """Creates a new image dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + image_dataset (ImageDataset): + Instantiated representation of the managed image dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.image + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py new file mode 100644 index 0000000000..3dd217aad7 --- /dev/null +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class TabularDataset(datasets._Dataset): + """Managed tabular dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.tabular, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "TabularDataset": + """Creates a new tabular dataset. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + tabular_dataset (TabularDataset): + Instantiated representation of the managed tabular dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.tabular + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + gcs_source=gcs_source, + bq_source=bq_source, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + def import_data(self): + raise NotImplementedError( + f"{self.__class__.__name__} class does not support 'import_data'" + ) diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py new file mode 100644 index 0000000000..c27fed59ad --- /dev/null +++ b/google/cloud/aiplatform/datasets/text_dataset.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Dict, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class TextDataset(datasets._Dataset): + """Managed text dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.text, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "TextDataset": + """Creates a new text dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Example Usage: + ds = aiplatform.TextDataset.create( + display_name='my-dataset', + gcs_source='gs://my-bucket/dataset.csv', + import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification + ) + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + text_dataset (TextDataset): + Instantiated representation of the managed text dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.text + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py new file mode 100644 index 0000000000..84af000df4 --- /dev/null +++ b/google/cloud/aiplatform/datasets/video_dataset.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Dict, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class VideoDataset(datasets._Dataset): + """Managed video dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.video, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "VideoDataset": + """Creates a new video dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + ``import_schema_uri``, + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + video_dataset (VideoDataset): + Instantiated representation of the managed video dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.video + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) diff --git a/google/cloud/aiplatform/explain/__init__.py b/google/cloud/aiplatform/explain/__init__.py new file mode 100644 index 0000000000..61b9181834 --- /dev/null +++ b/google/cloud/aiplatform/explain/__init__.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.compat.types import ( + explanation_metadata_v1beta1 as explanation_metadata, + explanation_v1beta1 as explanation, +) + +ExplanationMetadata = explanation_metadata.ExplanationMetadata + +# ExplanationMetadata subclasses +InputMetadata = ExplanationMetadata.InputMetadata +OutputMetadata = ExplanationMetadata.OutputMetadata + +# InputMetadata subclasses +Encoding = InputMetadata.Encoding +FeatureValueDomain = InputMetadata.FeatureValueDomain +Visualization = InputMetadata.Visualization + + +ExplanationParameters = explanation.ExplanationParameters +FeatureNoiseSigma = explanation.FeatureNoiseSigma + +# Classes used by ExplanationParameters +IntegratedGradientsAttribution = explanation.IntegratedGradientsAttribution + +SampledShapleyAttribution = explanation.SampledShapleyAttribution +SmoothGradConfig = explanation.SmoothGradConfig +XraiAttribution = explanation.XraiAttribution + + +__all__ = ( + "Encoding", + "ExplanationMetadata", + "ExplanationParameters", + "FeatureNoiseSigma", + "FeatureValueDomain", + "InputMetadata", + "IntegratedGradientsAttribution", + "OutputMetadata", + "SampledShapleyAttribution", + "SmoothGradConfig", + "Visualization", + "XraiAttribution", +) diff --git a/google/cloud/aiplatform/helpers/_decorators.py b/google/cloud/aiplatform/helpers/_decorators.py index 5d9aa28bea..95aac31c4f 100644 --- a/google/cloud/aiplatform/helpers/_decorators.py +++ b/google/cloud/aiplatform/helpers/_decorators.py @@ -68,3 +68,5 @@ def _from_map(map_): marshal = Marshal(name="google.cloud.aiplatform.v1beta1") marshal.register(Value, ConversionValueRule(marshal=marshal)) +marshal = Marshal(name="google.cloud.aiplatform.v1") +marshal.register(Value, ConversionValueRule(marshal=marshal)) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py new file mode 100644 index 0000000000..b84a006d02 --- /dev/null +++ b/google/cloud/aiplatform/initializer.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from concurrent import futures +import logging +import pkg_resources +import os +from typing import Optional, Type, Union + +from google.api_core import client_options +from google.api_core import gapic_v1 +import google.auth +from google.auth import credentials as auth_credentials +from google.auth.exceptions import GoogleAuthError + +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.types import ( + encryption_spec as gca_encryption_spec_compat, + encryption_spec_v1 as gca_encryption_spec_v1, + encryption_spec_v1beta1 as gca_encryption_spec_v1beta1, +) + + +class _Config: + """Stores common parameters and options for API calls.""" + + def __init__(self): + self._project = None + self._experiment = None + self._location = None + self._staging_bucket = None + self._credentials = None + self._encryption_spec_key_name = None + + def init( + self, + *, + project: Optional[str] = None, + location: Optional[str] = None, + experiment: Optional[str] = None, + staging_bucket: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + ): + """Updates common initalization parameters with provided options. + + Args: + project (str): The default project to use when making API calls. + location (str): The default location to use when making API calls. If not + set defaults to us-central-1 + experiment (str): The experiment to assign + staging_bucket (str): The default staging bucket to use to stage artifacts + when making API calls. In the form gs://... + credentials (google.auth.crendentials.Credentials): The default custom + credentials to use when making API calls. If not provided crendentials + will be ascertained from the environment. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect a resource. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this resource and all sub-resources will be secured by this key. + """ + if project: + self._project = project + if location: + utils.validate_region(location) + self._location = location + if experiment: + logging.warning("Experiments currently not supported.") + self._experiment = experiment + if staging_bucket: + self._staging_bucket = staging_bucket + if credentials: + self._credentials = credentials + if encryption_spec_key_name: + self._encryption_spec_key_name = encryption_spec_key_name + + def get_encryption_spec( + self, + encryption_spec_key_name: Optional[str], + select_version: Optional[str] = compat.DEFAULT_VERSION, + ) -> Optional[ + Union[ + gca_encryption_spec_v1.EncryptionSpec, + gca_encryption_spec_v1beta1.EncryptionSpec, + ] + ]: + """Creates a gca_encryption_spec.EncryptionSpec instance from the given key name. + If the provided key name is None, it uses the default key name if provided. + + Args: + encryption_spec_key_name (Optional[str]): The default encryption key name to use when creating resources. + select_version: The default version is set to compat.DEFAULT_VERSION + """ + kms_key_name = encryption_spec_key_name or self.encryption_spec_key_name + encryption_spec = None + if kms_key_name: + gca_encryption_spec = gca_encryption_spec_compat + if select_version == compat.V1BETA1: + gca_encryption_spec = gca_encryption_spec_v1beta1 + encryption_spec = gca_encryption_spec.EncryptionSpec( + kms_key_name=kms_key_name + ) + return encryption_spec + + @property + def project(self) -> str: + """Default project.""" + if self._project: + return self._project + + project_not_found_exception_str = ( + "Unable to find your project. Please provide a project ID by:" + "\n- Passing a constructor argument" + "\n- Using aiplatform.init()" + "\n- Setting a GCP environment variable" + ) + + try: + _, project_id = google.auth.default() + except GoogleAuthError: + raise GoogleAuthError(project_not_found_exception_str) + + if not project_id: + raise ValueError(project_not_found_exception_str) + + return project_id + + @property + def location(self) -> str: + """Default location.""" + return self._location or constants.DEFAULT_REGION + + @property + def experiment(self) -> Optional[str]: + """Default experiment, if provided.""" + return self._experiment + + @property + def staging_bucket(self) -> Optional[str]: + """Default staging bucket, if provided.""" + return self._staging_bucket + + @property + def credentials(self) -> Optional[auth_credentials.Credentials]: + """Default credentials.""" + if self._credentials: + return self._credentials + logger = logging.getLogger("google.auth._default") + logging_warning_filter = utils.LoggingWarningFilter() + logger.addFilter(logging_warning_filter) + credentials, _ = google.auth.default() + logger.removeFilter(logging_warning_filter) + return credentials + + @property + def encryption_spec_key_name(self) -> Optional[str]: + """Default encryption spec key name, if provided.""" + return self._encryption_spec_key_name + + def get_client_options( + self, location_override: Optional[str] = None + ) -> client_options.ClientOptions: + """Creates GAPIC client_options using location and type. + + Args: + location_override (str): + Set this parameter to get client options for a location different from + location set by initializer. Must be a GCP region supported by AI + Platform (Unified). + + Returns: + clients_options (google.api_core.client_options.ClientOptions): + A ClientOptions object set with regionalized API endpoint, i.e. + { "api_endpoint": "us-central1-aiplatform.googleapis.com" } or + { "api_endpoint": "asia-east1-aiplatform.googleapis.com" } + """ + if not (self.location or location_override): + raise ValueError( + "No location found. Provide or initialize SDK with a location." + ) + + region = location_override or self.location + region = region.lower() + + utils.validate_region(region) + + return client_options.ClientOptions( + api_endpoint=f"{region}-{constants.API_BASE_PATH}" + ) + + def common_location_path( + self, project: Optional[str] = None, location: Optional[str] = None + ) -> str: + """Get parent resource with optional project and location override. + + Args: + project (str): GCP project. If not provided will use the current project. + location (str): Location. If not provided will use the current location. + Returns: + resource_parent: Formatted parent resource string. + """ + if location: + utils.validate_region(location) + + return "/".join( + [ + "projects", + project or self.project, + "locations", + location or self.location, + ] + ) + + def create_client( + self, + client_class: Type[utils.AiPlatformServiceClientWithOverride], + credentials: Optional[auth_credentials.Credentials] = None, + location_override: Optional[str] = None, + prediction_client: bool = False, + ) -> utils.AiPlatformServiceClientWithOverride: + """Instantiates a given AiPlatformServiceClient with optional overrides. + + Args: + client_class (utils.AiPlatformServiceClientWithOverride): + (Required) An AI Platform Service Client with optional overrides. + credentials (auth_credentials.Credentials): + Custom auth credentials. If not provided will use the current config. + location_override (str): Optional location override. + prediction_client (str): Optional flag to use a prediction endpoint. + Returns: + client: Instantiated AI Platform Service client with optional overrides + """ + gapic_version = pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version + client_info = gapic_v1.client_info.ClientInfo( + gapic_version=gapic_version, user_agent=f"model-builder/{gapic_version}" + ) + + kwargs = { + "credentials": credentials or self.credentials, + "client_options": self.get_client_options( + location_override=location_override + ), + "client_info": client_info, + } + + return client_class(**kwargs) + + +# global config to store init parameters: ie, aiplatform.init(project=..., location=...) +global_config = _Config() + +global_pool = futures.ThreadPoolExecutor( + max_workers=min(32, max(4, (os.cpu_count() or 0) * 5)) +) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py new file mode 100644 index 0000000000..a7f2bbd31d --- /dev/null +++ b/google/cloud/aiplatform/jobs.py @@ -0,0 +1,795 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Iterable, Optional, Union, Sequence, Dict, List + +import abc +import sys +import time +import logging + +from google.cloud import storage +from google.cloud import bigquery + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + io as gca_io_compat, + io_v1beta1 as gca_io_v1beta1, + job_state as gca_job_state, + batch_prediction_job as gca_bp_job_compat, + batch_prediction_job_v1 as gca_bp_job_v1, + batch_prediction_job_v1beta1 as gca_bp_job_v1beta1, + machine_resources as gca_machine_resources_compat, + machine_resources_v1beta1 as gca_machine_resources_v1beta1, + explanation_v1beta1 as gca_explanation_v1beta1, +) + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) +_LOGGER = base.Logger(__name__) + +_JOB_COMPLETE_STATES = ( + gca_job_state.JobState.JOB_STATE_SUCCEEDED, + gca_job_state.JobState.JOB_STATE_FAILED, + gca_job_state.JobState.JOB_STATE_CANCELLED, + gca_job_state.JobState.JOB_STATE_PAUSED, +) + +_JOB_ERROR_STATES = ( + gca_job_state.JobState.JOB_STATE_FAILED, + gca_job_state.JobState.JOB_STATE_CANCELLED, +) + + +class _Job(base.AiPlatformResourceNounWithFutureManager): + """ + Class that represents a general Job resource in AI Platform (Unified). + Cannot be directly instantiated. + + Serves as base class to specific Job types, i.e. BatchPredictionJob or + DataLabelingJob to re-use shared functionality. + + Subclasses requires one class attribute: + + _getter_method (str): The name of JobServiceClient getter method for specific + Job type, i.e. 'get_custom_job' for CustomJob + _cancel_method (str): The name of the specific JobServiceClient cancel method + _delete_method (str): The name of the specific JobServiceClient delete method + """ + + client_class = utils.JobpointClientWithOverride + _is_client_prediction_client = False + + def __init__( + self, + job_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """ + Retrives Job subclass resource by calling a subclass-specific getter method. + + Args: + job_name (str): + Required. A fully-qualified job resource name or job ID. + Example: "projects/123/locations/us-central1/batchPredictionJobs/456" or + "456" when project, location and job_type are initialized or passed. + project: Optional[str] = None, + Optional project to retrieve Job subclass from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve Job subclass from. If not set, + location set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials] = None, + Custom credentials to use. If not set, credentials set in + aiplatform.init will be used. + """ + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=job_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=job_name) + + @property + def state(self) -> gca_job_state.JobState: + """Fetch Job again and return the current JobState. + + Returns: + state (job_state.JobState): + Enum that describes the state of a AI Platform job. + """ + + # Fetch the Job again for most up-to-date job state + self._sync_gca_resource() + + return self._gca_resource.state + + @property + @abc.abstractmethod + def _job_type(cls) -> str: + """Job type.""" + pass + + @property + @abc.abstractmethod + def _cancel_method(cls) -> str: + """Name of cancellation method for cancelling the specific job type.""" + pass + + def _dashboard_uri(self) -> Optional[str]: + """Helper method to compose the dashboard uri where job can be viewed.""" + fields = utils.extract_fields_from_resource_name(self.resource_name) + url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}" + return url + + def _block_until_complete(self): + """Helper method to block and check on job until complete. + + Raises: + RuntimeError: If job failed or cancelled. + + """ + + # Used these numbers so failures surface fast + wait = 5 # start at five seconds + log_wait = 5 + max_wait = 60 * 5 # 5 minute wait + multiplier = 2 # scale wait by 2 every iteration + + previous_time = time.time() + while self.state not in _JOB_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + time.sleep(wait) + + _LOGGER.log_action_completed_against_resource("", "run", self) + + # Error is only populated when the job state is + # JOB_STATE_FAILED or JOB_STATE_CANCELLED. + if self.state in _JOB_ERROR_STATES: + raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[base.AiPlatformResourceNoun]: + """List all instances of this Job Resource. + + Example Usage: + + aiplatform.BatchPredictionJobs.list( + filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of Job resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + def cancel(self) -> None: + """Cancels this Job. Success of cancellation is not guaranteed. Use `Job.state` + property to verify if cancellation was successful.""" + + _LOGGER.log_action_start_against_resource("Cancelling", "run", self) + getattr(self.api_client, self._cancel_method)(name=self.resource_name) + + +class BatchPredictionJob(_Job): + + _resource_noun = "batchPredictionJobs" + _getter_method = "get_batch_prediction_job" + _list_method = "list_batch_prediction_jobs" + _cancel_method = "cancel_batch_prediction_job" + _delete_method = "delete_batch_prediction_job" + _job_type = "batch-predictions" + + def __init__( + self, + batch_prediction_job_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """ + Retrieves a BatchPredictionJob resource and instantiates its representation. + + Args: + batch_prediction_job_name (str): + Required. A fully-qualified BatchPredictionJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or + "456" when project and location are initialized or passed. + project: Optional[str] = None, + Optional project to retrieve BatchPredictionJob from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve BatchPredictionJob from. If not set, + location set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials] = None, + Custom credentials to use. If not set, credentials set in + aiplatform.init will be used. + """ + + super().__init__( + job_name=batch_prediction_job_name, + project=project, + location=location, + credentials=credentials, + ) + + @classmethod + def create( + cls, + job_display_name: str, + model_name: str, + instances_format: str = "jsonl", + predictions_format: str = "jsonl", + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bigquery_source: Optional[str] = None, + gcs_destination_prefix: Optional[str] = None, + bigquery_destination_prefix: Optional[str] = None, + model_parameters: Optional[Dict] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + starting_replica_count: Optional[int] = None, + max_replica_count: Optional[int] = None, + generate_explanation: Optional[bool] = False, + explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, + explanation_parameters: Optional[ + "aiplatform.explain.ExplanationParameters" + ] = None, + labels: Optional[dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "BatchPredictionJob": + """Create a batch prediction job. + + Args: + job_display_name (str): + Required. The user-defined name of the BatchPredictionJob. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + model_name (str): + Required. A fully-qualified model resource name or model ID. + Example: "projects/123/locations/us-central1/models/456" or + "456" when project and location are initialized or passed. + instances_format (str): + Required. The format in which instances are given, must be one + of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip", + or "file-list". Default is "jsonl" when using `gcs_source`. If a + `bigquery_source` is provided, this is overriden to "bigquery". + predictions_format (str): + Required. The format in which AI Platform gives the + predictions, must be one of "jsonl", "csv", or "bigquery". + Default is "jsonl" when using `gcs_destination_prefix`. If a + `bigquery_destination_prefix` is provided, this is overriden to + "bigquery". + gcs_source (Optional[Sequence[str]]): + Google Cloud Storage URI(-s) to your instances to run + batch prediction on. They must match `instances_format`. + May contain wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + bigquery_source (Optional[str]): + BigQuery URI to a table, up to 2000 characters long. For example: + `projectId.bqDatasetId.bqTableId` + gcs_destination_prefix (Optional[str]): + The Google Cloud Storage location of the directory where the + output is to be written to. In the given directory a new + directory is created. Its name is + ``prediction--``, where + timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. + Inside of it files ``predictions_0001.``, + ``predictions_0002.``, ..., + ``predictions_N.`` are created where + ```` depends on chosen ``predictions_format``, + and N may equal 0001 and depends on the total number of + successfully predicted instances. If the Model has both + ``instance`` and ``prediction`` schemata defined then each such + file contains predictions as per the ``predictions_format``. + If prediction for any instance failed (partially or + completely), then an additional ``errors_0001.``, + ``errors_0002.``,..., ``errors_N.`` + files are created (N depends on total number of failed + predictions). These files contain the failed instances, as + per their schema, followed by an additional ``error`` field + which as value has ```google.rpc.Status`` `__ + containing only ``code`` and ``message`` fields. + bigquery_destination_prefix (Optional[str]): + The BigQuery project location where the output is to be + written to. In the given project a new dataset is created + with name + ``prediction__`` where + is made BigQuery-dataset-name compatible (for example, most + special characters become underscores), and timestamp is in + YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the + dataset two tables will be created, ``predictions``, and + ``errors``. If the Model has both ``instance`` and ``prediction`` + schemata defined then the tables have columns as follows: + The ``predictions`` table contains instances for which the + prediction succeeded, it has columns as per a concatenation + of the Model's instance and prediction schemata. The + ``errors`` table contains rows for which the prediction has + failed, it has instance columns, as per the instance schema, + followed by a single "errors" column, which as values has + ```google.rpc.Status`` `__ represented as a STRUCT, + and containing only ``code`` and ``message``. + model_parameters (Optional[Dict]): + The parameters that govern the predictions. The schema of + the parameters may be specified via the Model's `parameters_schema_uri`. + machine_type (Optional[str]): + The type of machine for running batch prediction on + dedicated resources. Not specifying machine type will result in + batch prediction job being run with automatic resources. + accelerator_type (Optional[str]): + The type of accelerator(s) that may be attached + to the machine as per `accelerator_count`. Only used if + `machine_type` is set. + accelerator_count (Optional[int]): + The number of accelerators to attach to the + `machine_type`. Only used if `machine_type` is set. + starting_replica_count (Optional[int]): + The number of machine replicas used at the start of the batch + operation. If not set, AI Platform decides starting number, not + greater than `max_replica_count`. Only used if `machine_type` is + set. + max_replica_count (Optional[int]): + The maximum number of machine replicas the batch operation may + be scaled to. Only used if `machine_type` is set. + Default is 10. + generate_explanation (bool): + Optional. Generate explanation along with the batch prediction + results. This will cause the batch prediction output to include + explanations based on the `prediction_format`: + - `bigquery`: output includes a column named `explanation`. The value + is a struct that conforms to the [aiplatform.gapic.Explanation] object. + - `jsonl`: The JSON objects on each line include an additional entry + keyed `explanation`. The value of the entry is a JSON object that + conforms to the [aiplatform.gapic.Explanation] object. + - `csv`: Generating explanations for CSV format is not supported. + explanation_metadata (aiplatform.explain.ExplanationMetadata): + Optional. Explanation metadata configuration for this BatchPredictionJob. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_metadata`. + All fields of `explanation_metadata` are optional in the request. If + a field of the `explanation_metadata` object is not populated, the + corresponding field of the `Model.explanation_metadata` object is inherited. + For more details, see `Ref docs ` + explanation_parameters (aiplatform.explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_parameters`. + All fields of `explanation_parameters` are optional in the request. If + a field of the `explanation_parameters` object is not populated, the + corresponding field of the `Model.explanation_parameters` object is inherited. + For more details, see `Ref docs ` + labels (Optional[dict]): + The labels with user-defined metadata to organize your + BatchPredictionJobs. Label keys and values can be no longer than + 64 characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. See https://goo.gl/xmQnxf + for more information and examples of labels. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to create this batch prediction + job. Overrides credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the job. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If this is set, then all + resources created by the BatchPredictionJob will + be encrypted with the provided encryption key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + (jobs.BatchPredictionJob): + Instantiated representation of the created batch prediction job. + + """ + + utils.validate_display_name(job_display_name) + + model_name = utils.full_resource_name( + resource_name=model_name, + resource_noun="models", + project=project, + location=location, + ) + + # Raise error if both or neither source URIs are provided + if bool(gcs_source) == bool(bigquery_source): + raise ValueError( + "Please provide either a gcs_source or bigquery_source, " + "but not both." + ) + + # Raise error if both or neither destination prefixes are provided + if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix): + raise ValueError( + "Please provide either a gcs_destination_prefix or " + "bigquery_destination_prefix, but not both." + ) + + # Raise error if unsupported instance format is provided + if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS: + raise ValueError( + f"{predictions_format} is not an accepted instances format " + f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}" + ) + + # Raise error if unsupported prediction format is provided + if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS: + raise ValueError( + f"{predictions_format} is not an accepted prediction format " + f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" + ) + gca_bp_job = gca_bp_job_compat + gca_io = gca_io_compat + gca_machine_resources = gca_machine_resources_compat + select_version = compat.DEFAULT_VERSION + if generate_explanation: + gca_bp_job = gca_bp_job_v1beta1 + gca_io = gca_io_v1beta1 + gca_machine_resources = gca_machine_resources_v1beta1 + select_version = compat.V1BETA1 + + gapic_batch_prediction_job = gca_bp_job.BatchPredictionJob() + + # Required Fields + gapic_batch_prediction_job.display_name = job_display_name + gapic_batch_prediction_job.model = model_name + + input_config = gca_bp_job.BatchPredictionJob.InputConfig() + output_config = gca_bp_job.BatchPredictionJob.OutputConfig() + + if bigquery_source: + input_config.instances_format = "bigquery" + input_config.bigquery_source = gca_io.BigQuerySource() + input_config.bigquery_source.input_uri = bigquery_source + else: + input_config.instances_format = instances_format + input_config.gcs_source = gca_io.GcsSource( + uris=gcs_source if type(gcs_source) == list else [gcs_source] + ) + + if bigquery_destination_prefix: + output_config.predictions_format = "bigquery" + output_config.bigquery_destination = gca_io.BigQueryDestination() + + bq_dest_prefix = bigquery_destination_prefix + + if not bq_dest_prefix.startswith("bq://"): + bq_dest_prefix = f"bq://{bq_dest_prefix}" + + output_config.bigquery_destination.output_uri = bq_dest_prefix + else: + output_config.predictions_format = predictions_format + output_config.gcs_destination = gca_io.GcsDestination( + output_uri_prefix=gcs_destination_prefix + ) + + gapic_batch_prediction_job.input_config = input_config + gapic_batch_prediction_job.output_config = output_config + + # Optional Fields + gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + select_version=select_version, + ) + + if model_parameters: + gapic_batch_prediction_job.model_parameters = model_parameters + + # Custom Compute + if machine_type: + + machine_spec = gca_machine_resources.MachineSpec() + machine_spec.machine_type = machine_type + machine_spec.accelerator_type = accelerator_type + machine_spec.accelerator_count = accelerator_count + + dedicated_resources = gca_machine_resources.BatchDedicatedResources() + + dedicated_resources.machine_spec = machine_spec + dedicated_resources.starting_replica_count = starting_replica_count + dedicated_resources.max_replica_count = max_replica_count + + gapic_batch_prediction_job.dedicated_resources = dedicated_resources + + gapic_batch_prediction_job.manual_batch_tuning_parameters = None + + # User Labels + gapic_batch_prediction_job.labels = labels + + # Explanations + if generate_explanation: + gapic_batch_prediction_job.generate_explanation = generate_explanation + + if explanation_metadata or explanation_parameters: + gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec( + metadata=explanation_metadata, parameters=explanation_parameters + ) + + # TODO (b/174502913): Support private feature once released + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + return cls._create( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + batch_prediction_job=gapic_batch_prediction_job, + generate_explanation=generate_explanation, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + sync=sync, + ) + + @classmethod + @base.optional_sync() + def _create( + cls, + api_client: job_service_client.JobServiceClient, + parent: str, + batch_prediction_job: Union[ + gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob + ], + generate_explanation: bool, + project: str, + location: str, + credentials: Optional[auth_credentials.Credentials], + sync: bool = True, + ) -> "BatchPredictionJob": + """Create a batch prediction job. + + Args: + api_client (dataset_service_client.DatasetServiceClient): + Required. An instance of DatasetServiceClient with the correct api_endpoint + already set based on user's preferences. + batch_prediction_job (gca_bp_job.BatchPredictionJob): + Required. a batch prediction job proto for creating a batch prediction job on AI Platform. + generate_explanation (bool): + Required. Generate explanation along with the batch prediction + results. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + Example: "projects/my-prj/locations/us-central1" + project (str): + Required. Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Required. Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + Returns: + (jobs.BatchPredictionJob): + Instantiated representation of the created batch prediction job. + + Raises: + ValueError: + If no or multiple source or destinations are provided. Also, if + provided instances_format or predictions_format are not supported + by AI Platform. + + """ + # select v1beta1 if explain else use default v1 + if generate_explanation: + api_client = api_client.select_version(compat.V1BETA1) + + _LOGGER.log_create_with_lro(cls) + + gca_batch_prediction_job = api_client.create_batch_prediction_job( + parent=parent, batch_prediction_job=batch_prediction_job + ) + + batch_prediction_job = cls( + batch_prediction_job_name=gca_batch_prediction_job.name, + project=project, + location=location, + credentials=credentials, + ) + + _LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj") + + _LOGGER.info( + "View Batch Prediction Job:\n%s" % batch_prediction_job._dashboard_uri() + ) + + batch_prediction_job._block_until_complete() + + return batch_prediction_job + + def iter_outputs( + self, bq_max_results: Optional[int] = 100 + ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: + """Returns an Iterable object to traverse the output files, either a list + of GCS Blobs or a BigQuery RowIterator depending on the output config set + when the BatchPredictionJob was created. + + Args: + bq_max_results: Optional[int] = 100 + Limit on rows to retrieve from prediction table in BigQuery dataset. + Only used when retrieving predictions from a bigquery_destination_prefix. + Default is 100. + + Returns: + Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: + Either a list of GCS Blob objects within the prediction output + directory or an iterable BigQuery RowIterator with predictions. + + Raises: + RuntimeError: + If BatchPredictionJob is in a JobState other than SUCCEEDED, + since outputs cannot be retrieved until the Job has finished. + NotImplementedError: + If BatchPredictionJob succeeded and output_info does not have a + GCS or BQ output provided. + """ + + if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED: + raise RuntimeError( + f"Cannot read outputs until BatchPredictionJob has succeeded, " + f"current state: {self._gca_resource.state}" + ) + + output_info = self._gca_resource.output_info + + # GCS Destination, return Blobs + if output_info.gcs_output_directory: + + # Build a Storage Client using the same credentials as JobServiceClient + storage_client = storage.Client( + credentials=self.api_client._transport._credentials + ) + + gcs_bucket, gcs_prefix = utils.extract_bucket_and_prefix_from_gcs_path( + output_info.gcs_output_directory + ) + + blobs = storage_client.list_blobs(gcs_bucket, prefix=gcs_prefix) + + return blobs + + # BigQuery Destination, return RowIterator + elif output_info.bigquery_output_dataset: + + # Build a BigQuery Client using the same credentials as JobServiceClient + bq_client = bigquery.Client( + credentials=self.api_client._transport._credentials + ) + + # Format from service is `bq://projectId.bqDatasetId` + bq_dataset = output_info.bigquery_output_dataset + + if bq_dataset.startswith("bq://"): + bq_dataset = bq_dataset[5:] + + # # Split project ID and BQ dataset ID + _, bq_dataset_id = bq_dataset.split(".", 1) + + row_iterator = bq_client.list_rows( + table=f"{bq_dataset_id}.predictions", max_results=bq_max_results + ) + + return row_iterator + + # Unknown Destination type + else: + raise NotImplementedError( + f"Unsupported batch prediction output location, here are details" + f"on your prediction output:\n{output_info}" + ) + + +class CustomJob(_Job): + _resource_noun = "customJobs" + _getter_method = "get_custom_job" + _list_method = "list_custom_job" + _cancel_method = "cancel_custom_job" + _delete_method = "delete_custom_job" + _job_type = "training" + pass + + +class DataLabelingJob(_Job): + _resource_noun = "dataLabelingJobs" + _getter_method = "get_data_labeling_job" + _list_method = "list_data_labeling_jobs" + _cancel_method = "cancel_data_labeling_job" + _delete_method = "delete_data_labeling_job" + _job_type = "labeling-tasks" + pass + + +class HyperparameterTuningJob(_Job): + _resource_noun = "hyperparameterTuningJobs" + _getter_method = "get_hyperparameter_tuning_job" + _list_method = "list_hyperparameter_tuning_jobs" + _cancel_method = "cancel_hyperparameter_tuning_job" + _delete_method = "delete_hyperparameter_tuning_job" + pass diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py new file mode 100644 index 0000000000..d96b681695 --- /dev/null +++ b/google/cloud/aiplatform/models.py @@ -0,0 +1,1997 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto +from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import explain +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.services import endpoint_service_client + +from google.cloud.aiplatform.compat.types import ( + encryption_spec as gca_encryption_spec, + endpoint as gca_endpoint_compat, + endpoint_v1 as gca_endpoint_v1, + endpoint_v1beta1 as gca_endpoint_v1beta1, + explanation_v1beta1 as gca_explanation_v1beta1, + machine_resources as gca_machine_resources_compat, + machine_resources_v1beta1 as gca_machine_resources_v1beta1, + model as gca_model_compat, + model_v1beta1 as gca_model_v1beta1, + env_var as gca_env_var_compat, + env_var_v1beta1 as gca_env_var_v1beta1, +) + +from google.protobuf import json_format + + +_LOGGER = base.Logger(__name__) + + +class Prediction(NamedTuple): + """Prediction class envelopes returned Model predictions and the Model id. + + Attributes: + predictions: + The predictions that are the output of the predictions + call. The schema of any single prediction may be specified via + Endpoint's DeployedModels' [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + deployed_model_id: + ID of the Endpoint's DeployedModel that served this prediction. + explanations: + The explanations of the Model's predictions. It has the same number + of elements as instances to be explained. Default is None. + """ + + predictions: Dict[str, List] + deployed_model_id: str + explanations: Optional[Sequence[gca_explanation_v1beta1.Explanation]] = None + + +class Endpoint(base.AiPlatformResourceNounWithFutureManager): + + client_class = utils.EndpointClientWithOverride + _is_client_prediction_client = False + _resource_noun = "endpoints" + _getter_method = "get_endpoint" + _list_method = "list_endpoints" + _delete_method = "delete_endpoint" + + def __init__( + self, + endpoint_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an endpoint resource. + + Args: + endpoint_name (str): + Required. A fully-qualified endpoint resource name or endpoint ID. + Example: "projects/123/locations/us-central1/endpoints/456" or + "456" when project and location are initialized or passed. + project (str): + Optional. Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve endpoint from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=endpoint_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=endpoint_name) + self._prediction_client = self._instantiate_prediction_client( + location=location or initializer.global_config.location, + credentials=credentials, + ) + + @classmethod + def create( + cls, + display_name: str, + description: Optional[str] = None, + labels: Optional[Dict] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync=True, + ) -> "Endpoint": + """Creates a new endpoint. + + Args: + display_name (str): + Required. The user-defined name of the Endpoint. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + project (str): + Required. Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Required. Location to retrieve endpoint from. If not set, location + set in aiplatform.init will be used. + description (str): + Optional. The description of the Endpoint. + labels (Dict): + Optional. The labels with user-defined metadata to + organize your Endpoints. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Endpoint and all sub-resources of this Endpoint will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + endpoint (endpoint.Endpoint): + Created endpoint. + """ + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + utils.validate_display_name(display_name) + + project = project or initializer.global_config.project + location = location or initializer.global_config.location + + return cls._create( + api_client=api_client, + display_name=display_name, + project=project, + location=location, + description=description, + labels=labels, + metadata=metadata, + credentials=credentials, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + @classmethod + @base.optional_sync() + def _create( + cls, + api_client: endpoint_service_client.EndpointServiceClient, + display_name: str, + project: str, + location: str, + description: Optional[str] = None, + labels: Optional[Dict] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, + sync=True, + ) -> "Endpoint": + """ + Creates a new endpoint by calling the API client. + Args: + api_client (EndpointServiceClient): + Required. An instance of EndpointServiceClient with the correct + api_endpoint already set based on user's preferences. + display_name (str): + Required. The user-defined name of the Endpoint. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + project (str): + Required. Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Required. Location to retrieve endpoint from. If not set, location + set in aiplatform.init will be used. + description (str): + Optional. The description of the Endpoint. + labels (Dict): + Optional. The labels with user-defined metadata to + organize your Endpoints. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]): + Optional. The Cloud KMS customer managed encryption key used to protect the dataset. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + sync (bool): + Whether to create this endpoint synchronously. + Returns: + endpoint (endpoint.Endpoint): + Created endpoint. + """ + + parent = initializer.global_config.common_location_path( + project=project, location=location + ) + + gapic_endpoint = gca_endpoint_compat.Endpoint( + display_name=display_name, + description=description, + labels=labels, + encryption_spec=encryption_spec, + ) + + operation_future = api_client.create_endpoint( + parent=parent, endpoint=gapic_endpoint, metadata=metadata + ) + + _LOGGER.log_create_with_lro(cls, operation_future) + + created_endpoint = operation_future.result() + + _LOGGER.log_create_complete(cls, created_endpoint, "endpoint") + + return cls( + endpoint_name=created_endpoint.name, + project=project, + location=location, + credentials=credentials, + ) + + @staticmethod + def _allocate_traffic( + traffic_split: Dict[str, int], traffic_percentage: int, + ) -> Dict[str, int]: + """ + Allocates desired traffic to new deployed model and scales traffic of + older deployed models. + + Args: + traffic_split (Dict[str, int]): + Required. Current traffic split of deployed models in endpoint. + traffic_percentage (int): + Required. Desired traffic to new deployed model. + Returns: + new_traffic_split (Dict[str, int]): + Traffic split to use. + """ + new_traffic_split = {} + old_models_traffic = 100 - traffic_percentage + if old_models_traffic: + unallocated_traffic = old_models_traffic + for deployed_model in traffic_split: + current_traffic = traffic_split[deployed_model] + new_traffic = int(current_traffic / 100 * old_models_traffic) + new_traffic_split[deployed_model] = new_traffic + unallocated_traffic -= new_traffic + # will likely under-allocate. make total 100. + for deployed_model in new_traffic_split: + if unallocated_traffic == 0: + break + new_traffic_split[deployed_model] += 1 + unallocated_traffic -= 1 + + new_traffic_split["0"] = traffic_percentage + + return new_traffic_split + + @staticmethod + def _unallocate_traffic( + traffic_split: Dict[str, int], deployed_model_id: str, + ) -> Dict[str, int]: + """ + Sets deployed model id's traffic to 0 and scales the traffic of other + deployed models. + + Args: + traffic_split (Dict[str, int]): + Required. Current traffic split of deployed models in endpoint. + deployed_model_id (str): + Required. Desired traffic to new deployed model. + Returns: + new_traffic_split (Dict[str, int]): + Traffic split to use. + """ + new_traffic_split = traffic_split.copy() + del new_traffic_split[deployed_model_id] + deployed_model_id_traffic = traffic_split[deployed_model_id] + traffic_percent_left = 100 - deployed_model_id_traffic + + if traffic_percent_left: + unallocated_traffic = 100 + for deployed_model in new_traffic_split: + current_traffic = traffic_split[deployed_model] + new_traffic = int(current_traffic / traffic_percent_left * 100) + new_traffic_split[deployed_model] = new_traffic + unallocated_traffic -= new_traffic + # will likely under-allocate. make total 100. + for deployed_model in new_traffic_split: + if unallocated_traffic == 0: + break + new_traffic_split[deployed_model] += 1 + unallocated_traffic -= 1 + + new_traffic_split[deployed_model_id] = 0 + + return new_traffic_split + + @staticmethod + def _validate_deploy_args( + min_replica_count: int, + max_replica_count: int, + accelerator_type: Optional[str], + deployed_model_display_name: Optional[str], + traffic_split: Optional[Dict[str, int]], + traffic_percentage: int, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + ): + """Helper method to validate deploy arguments. + + Args: + min_replica_count (int): + Required. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Required. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + accelerator_type (str): + Required. Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + deployed_model_display_name (str): + Required. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_split (Dict[str, int]): + Required. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + traffic_percentage (int): + Required. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + + Raises: + ValueError if Min or Max replica is negative. Traffic percentage > 100 or + < 0. Or if traffic_split does not sum to 100. + + ValueError if either explanation_metadata or explanation_parameters + but not both are specified. + """ + if min_replica_count < 0: + raise ValueError("Min replica cannot be negative.") + if max_replica_count < 0: + raise ValueError("Max replica cannot be negative.") + if deployed_model_display_name is not None: + utils.validate_display_name(deployed_model_display_name) + + if traffic_split is None: + if traffic_percentage > 100: + raise ValueError("Traffic percentage cannot be greater than 100.") + if traffic_percentage < 0: + raise ValueError("Traffic percentage cannot be negative.") + + elif traffic_split: + # TODO(b/172678233) verify every referenced deployed model exists + if sum(traffic_split.values()) != 100: + raise ValueError( + "Sum of all traffic within traffic split needs to be 100." + ) + + if bool(explanation_metadata) != bool(explanation_parameters): + raise ValueError( + "Both `explanation_metadata` and `explanation_parameters` should be specified or None." + ) + + # Raises ValueError if invalid accelerator + if accelerator_type: + utils.validate_accelerator_type(accelerator_type) + + def deploy( + self, + model: "Model", + deployed_model_display_name: Optional[str] = None, + traffic_percentage: int = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: int = 1, + max_replica_count: int = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """ + Deploys a Model to the Endpoint. + + Args: + model (aiplatform.Model): + Required. Model to be deployed. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + + self._validate_deploy_args( + min_replica_count, + max_replica_count, + accelerator_type, + deployed_model_display_name, + traffic_split, + traffic_percentage, + explanation_metadata, + explanation_parameters, + ) + + self._deploy( + model=model, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + sync=sync, + ) + + @base.optional_sync() + def _deploy( + self, + model: "Model", + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """ + Deploys a Model to the Endpoint. + + Args: + model (aiplatform.Model): + Required. Model to be deployed. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Raises: + ValueError if there is not current traffic split and traffic percentage + is not 0 or 100. + """ + _LOGGER.log_action_start_against_resource( + f"Deploying Model {model.resource_name} to", "", self + ) + + self._deploy_call( + self.api_client, + self.resource_name, + model.resource_name, + self._gca_resource.traffic_split, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + ) + + _LOGGER.log_action_completed_against_resource("model", "deployed", self) + + self._sync_gca_resource() + + @classmethod + def _deploy_call( + cls, + api_client: endpoint_service_client.EndpointServiceClient, + endpoint_resource_name: str, + model_resource_name: str, + endpoint_resource_traffic_split: Optional[proto.MapField] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + ): + """Helper method to deploy model to endpoint. + + Args: + api_client (endpoint_service_client.EndpointServiceClient): + Required. endpoint_service_client.EndpointServiceClient to make call. + endpoint_resource_name (str): + Required. Endpoint resource name to deploy model to. + model_resource_name (str): + Required. Model resource name of Model to deploy. + endpoint_resource_traffic_split (proto.MapField): + Optional. Endpoint current resource traffic split. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Raises: + ValueError if there is not current traffic split and traffic percentage + is not 0 or 100. + ValueError if only `explanation_metadata` or `explanation_parameters` + is specified. + """ + + max_replica_count = max(min_replica_count, max_replica_count) + + if bool(accelerator_type) != bool(accelerator_count): + raise ValueError( + "Both `accelerator_type` and `accelerator_count` should be specified or None." + ) + + gca_endpoint = gca_endpoint_compat + gca_machine_resources = gca_machine_resources_compat + if explanation_metadata and explanation_parameters: + gca_endpoint = gca_endpoint_v1beta1 + gca_machine_resources = gca_machine_resources_v1beta1 + + if machine_type: + machine_spec = gca_machine_resources.MachineSpec(machine_type=machine_type) + + if accelerator_type and accelerator_count: + utils.validate_accelerator_type(accelerator_type) + machine_spec.accelerator_type = accelerator_type + machine_spec.accelerator_count = accelerator_count + + dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=machine_spec, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + ) + deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=dedicated_resources, + model=model_resource_name, + display_name=deployed_model_display_name, + ) + else: + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=model_resource_name, + display_name=deployed_model_display_name, + ) + + # Service will throw error if both metadata and parameters are not provided + if explanation_metadata and explanation_parameters: + api_client = api_client.select_version(compat.V1BETA1) + explanation_spec = gca_endpoint.explanation.ExplanationSpec() + explanation_spec.metadata = explanation_metadata + explanation_spec.parameters = explanation_parameters + deployed_model.explanation_spec = explanation_spec + + if traffic_split is None: + # new model traffic needs to be 100 if no pre-existing models + if not endpoint_resource_traffic_split: + # default scenario + if traffic_percentage == 0: + traffic_percentage = 100 + # verify user specified 100 + elif traffic_percentage < 100: + raise ValueError( + """There are currently no deployed models so the traffic + percentage for this deployed model needs to be 100.""" + ) + traffic_split = cls._allocate_traffic( + traffic_split=dict(endpoint_resource_traffic_split), + traffic_percentage=traffic_percentage, + ) + + operation_future = api_client.deploy_model( + endpoint=endpoint_resource_name, + deployed_model=deployed_model, + traffic_split=traffic_split, + metadata=metadata, + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Deploy", "model", cls, operation_future + ) + + operation_future.result() + + def undeploy( + self, + deployed_model_id: str, + traffic_split: Optional[Dict[str, int]] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """Undeploys a deployed model. + + Proportionally adjusts the traffic_split among the remaining deployed + models of the endpoint. + + Args: + deployed_model_id (str): + Required. The ID of the DeployedModel to be undeployed from the + Endpoint. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + """ + if traffic_split is not None: + if deployed_model_id in traffic_split and traffic_split[deployed_model_id]: + raise ValueError("Model being undeployed should have 0 traffic.") + if sum(traffic_split.values()) != 100: + # TODO(b/172678233) verify every referenced deployed model exists + raise ValueError( + "Sum of all traffic within traffic split needs to be 100." + ) + + self._undeploy( + deployed_model_id=deployed_model_id, + traffic_split=traffic_split, + metadata=metadata, + sync=sync, + ) + + @base.optional_sync() + def _undeploy( + self, + deployed_model_id: str, + traffic_split: Optional[Dict[str, int]] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """Undeploys a deployed model. + + Proportionally adjusts the traffic_split among the remaining deployed + models of the endpoint. + + Args: + deployed_model_id (str): + Required. The ID of the DeployedModel to be undeployed from the + Endpoint. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + """ + current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split) + + if deployed_model_id in current_traffic_split: + current_traffic_split = self._unallocate_traffic( + traffic_split=current_traffic_split, + deployed_model_id=deployed_model_id, + ) + current_traffic_split.pop(deployed_model_id) + + _LOGGER.log_action_start_against_resource("Undeploying", "model", self) + + operation_future = self.api_client.undeploy_model( + endpoint=self.resource_name, + deployed_model_id=deployed_model_id, + traffic_split=current_traffic_split, + metadata=metadata, + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Undeploy", "model", self.__class__, operation_future + ) + + # block before returning + operation_future.result() + + _LOGGER.log_action_completed_against_resource("model", "undeployed", self) + + # update local resource + self._sync_gca_resource() + + @staticmethod + def _instantiate_prediction_client( + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> utils.PredictionClientWithOverride: + + """Helper method to instantiates prediction client with optional overrides for this endpoint. + + Args: + location (str): The location of this endpoint. + credentials (google.auth.credentials.Credentials): + Optional custom credentials to use when accessing interacting with + the prediction client. + Returns: + prediction_client (prediction_service_client.PredictionServiceClient): + Initalized prediction client with optional overrides. + """ + return initializer.global_config.create_client( + client_class=utils.PredictionClientWithOverride, + credentials=credentials, + location_override=location, + prediction_client=True, + ) + + def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction: + """Make a prediction against this Endpoint. + + Args: + instances (List): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (Dict): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + Returns: + prediction: Prediction with returned predictions and Model Id. + + """ + self.wait() + + prediction_response = self._prediction_client.predict( + endpoint=self.resource_name, instances=instances, parameters=parameters + ) + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in prediction_response.predictions.pb + ], + deployed_model_id=prediction_response.deployed_model_id, + ) + + def explain( + self, + instances: List[Dict], + parameters: Optional[Dict] = None, + deployed_model_id: Optional[str] = None, + ) -> Prediction: + """Make a prediction with explanations against this Endpoint. + + Example usage: + response = my_endpoint.explain(instances=[...]) + my_explanations = response.explanations + + Args: + instances (List): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (Dict): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + deployed_model_id (str): + Optional. If specified, this ExplainRequest will be served by the + chosen DeployedModel, overriding this Endpoint's traffic split. + Returns: + prediction: Prediction with returned predictions, explanations and Model Id. + """ + self.wait() + + explain_response = self._prediction_client.select_version( + compat.V1BETA1 + ).explain( + endpoint=self.resource_name, + instances=instances, + parameters=parameters, + deployed_model_id=deployed_model_id, + ) + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in explain_response.predictions.pb + ], + deployed_model_id=explain_response.deployed_model_id, + explanations=explain_response.explanations, + ) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["models.Endpoint"]: + """List all Endpoint resource instances. + + Example Usage: + + aiplatform.Endpoint.list( + filter='labels.my_label="my_label_value" OR display_name=!"old_endpoint"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[models.Endpoint] - A list of Endpoint resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + def list_models( + self, + ) -> Sequence[ + Union[gca_endpoint_v1.DeployedModel, gca_endpoint_v1beta1.DeployedModel] + ]: + """Returns a list of the models deployed to this Endpoint. + + Returns: + deployed_models (Sequence[aiplatform.gapic.DeployedModel]): + A list of the models deployed in this Endpoint. + """ + self._sync_gca_resource() + return self._gca_resource.deployed_models + + def undeploy_all(self, sync: bool = True) -> "Endpoint": + """Undeploys every model deployed to this Endpoint. + + Args: + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + self._sync_gca_resource() + + for deployed_model in self._gca_resource.deployed_models: + self._undeploy(deployed_model_id=deployed_model.id, sync=sync) + + return self + + def delete(self, force: bool = False, sync: bool = True) -> None: + """Deletes this AI Platform Endpoint resource. If force is set to True, + all models on this Endpoint will be undeployed prior to deletion. + + Args: + force (bool): + Required. If force is set to True, all deployed models on this + Endpoint will be undeployed first. Default is False. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Raises: + FailedPrecondition: If models are deployed on this Endpoint and force = False. + """ + if force: + self.undeploy_all(sync=sync) + + super().delete(sync=sync) + + +class Model(base.AiPlatformResourceNounWithFutureManager): + + client_class = utils.ModelClientWithOverride + _is_client_prediction_client = False + _resource_noun = "models" + _getter_method = "get_model" + _list_method = "list_models" + _delete_method = "delete_model" + + @property + def uri(self): + """Uri of the model.""" + return self._gca_resource.artifact_uri + + @property + def description(self): + """Description of the model.""" + return self._gca_resource.description + + def __init__( + self, + model_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves the model resource and instantiates its representation. + + Args: + model_name (str): + Required. A fully-qualified model resource name or model ID. + Example: "projects/123/locations/us-central1/models/456" or + "456" when project and location are initialized or passed. + project (str): + Optional project to retrieve model from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve model from. If not set, location + set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials]=None, + Custom credentials to use to upload this model. If not set, + credentials set in aiplatform.init will be used. + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=model_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=model_name) + + # TODO(b/170979552) Add support for predict schemata + # TODO(b/170979926) Add support for metadata and metadata schema + @classmethod + @base.optional_sync() + def upload( + cls, + display_name: str, + serving_container_image_uri: str, + *, + artifact_uri: Optional[str] = None, + serving_container_predict_route: Optional[str] = None, + serving_container_health_route: Optional[str] = None, + description: Optional[str] = None, + serving_container_command: Optional[Sequence[str]] = None, + serving_container_args: Optional[Sequence[str]] = None, + serving_container_environment_variables: Optional[Dict[str, str]] = None, + serving_container_ports: Optional[Sequence[int]] = None, + instance_schema_uri: Optional[str] = None, + parameters_schema_uri: Optional[str] = None, + prediction_schema_uri: Optional[str] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync=True, + ) -> "Model": + """Uploads a model and returns a Model representing the uploaded Model resource. + + Example usage: + + my_model = Model.upload( + display_name='my-model', + artifact_uri='gs://my-model/saved-model' + serving_container_image_uri='tensorflow/serving' + ) + + Args: + display_name (str): + Required. The display name of the Model. The name can be up to 128 + characters long and can be consist of any UTF-8 characters. + serving_container_image_uri (str): + Required. The URI of the Model serving container. + artifact_uri (str): + Optional. The path to the directory containing the Model artifact and + any of its supporting files. Leave blank for custom container prediction. + Not present for AutoML Models. + serving_container_predict_route (str): + Optional. An HTTP path to send prediction requests to the container, and + which must be supported by it. If not specified a default HTTP path will + be used by AI Platform. + serving_container_health_route (str): + Optional. An HTTP path to send health check requests to the container, and which + must be supported by it. If not specified a standard HTTP path will be + used by AI Platform. + description (str): + The description of the model. + serving_container_command: Optional[Sequence[str]]=None, + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + serving_container_args: Optional[Sequence[str]]=None, + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + serving_container_environment_variables: Optional[Dict[str, str]]=None, + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + serving_container_ports: Optional[Sequence[int]]=None, + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + project: Optional[str]=None, + Project to upload this model to. Overrides project set in + aiplatform.init. + location: Optional[str]=None, + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials: Optional[auth_credentials.Credentials]=None, + Custom credentials to use to upload this model. Overrides credentials + set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Returns: + model: Instantiated representation of the uploaded model resource. + Raises: + ValueError if only `explanation_metadata` or `explanation_parameters` + is specified. + """ + utils.validate_display_name(display_name) + + if bool(explanation_metadata) != bool(explanation_parameters): + raise ValueError( + "Both `explanation_metadata` and `explanation_parameters` should be specified or None." + ) + + gca_endpoint = gca_endpoint_compat + gca_model = gca_model_compat + gca_env_var = gca_env_var_compat + if explanation_metadata and explanation_parameters: + gca_endpoint = gca_endpoint_v1beta1 + gca_model = gca_model_v1beta1 + gca_env_var = gca_env_var_v1beta1 + + api_client = cls._instantiate_client(location, credentials) + env = None + ports = None + + if serving_container_environment_variables: + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in serving_container_environment_variables.items() + ] + if serving_container_ports: + ports = [ + gca_model.Port(container_port=port) for port in serving_container_ports + ] + + container_spec = gca_model.ModelContainerSpec( + image_uri=serving_container_image_uri, + command=serving_container_command, + args=serving_container_args, + env=env, + ports=ports, + predict_route=serving_container_predict_route, + health_route=serving_container_health_route, + ) + + model_predict_schemata = None + if any([instance_schema_uri, parameters_schema_uri, prediction_schema_uri]): + model_predict_schemata = gca_model.PredictSchemata( + instance_schema_uri=instance_schema_uri, + parameters_schema_uri=parameters_schema_uri, + prediction_schema_uri=prediction_schema_uri, + ) + + # TODO(b/182388545) initializer.global_config.get_encryption_spec from a sync function + encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + ) + + managed_model = gca_model.Model( + display_name=display_name, + description=description, + container_spec=container_spec, + predict_schemata=model_predict_schemata, + encryption_spec=encryption_spec, + ) + + if artifact_uri: + managed_model.artifact_uri = artifact_uri + + # Override explanation_spec if both required fields are provided + if explanation_metadata and explanation_parameters: + api_client = api_client.select_version(compat.V1BETA1) + explanation_spec = gca_endpoint.explanation.ExplanationSpec() + explanation_spec.metadata = explanation_metadata + explanation_spec.parameters = explanation_parameters + managed_model.explanation_spec = explanation_spec + + lro = api_client.upload_model( + parent=initializer.global_config.common_location_path(project, location), + model=managed_model, + ) + + _LOGGER.log_create_with_lro(cls, lro) + + model_upload_response = lro.result() + + this_model = cls(model_upload_response.model) + + _LOGGER.log_create_complete(cls, this_model._gca_resource, "model") + + return this_model + + # TODO(b/172502059) support deploying with endpoint resource name + def deploy( + self, + endpoint: Optional["Endpoint"] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync=True, + ) -> Endpoint: + """ + Deploys model to endpoint. Endpoint will be created if unspecified. + + Args: + endpoint ("Endpoint"): + Optional. Endpoint to deploy model to. If not specified, endpoint + display name will be model display name+'_endpoint'. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the smaller value of min_replica_count or 1 will + be used. + accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + endpoint ("Endpoint"): + Endpoint with the deployed model. + + """ + + Endpoint._validate_deploy_args( + min_replica_count, + max_replica_count, + accelerator_type, + deployed_model_display_name, + traffic_split, + traffic_percentage, + explanation_metadata, + explanation_parameters, + ) + + return self._deploy( + endpoint=endpoint, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + encryption_spec_key_name=encryption_spec_key_name + or initializer.global_config.encryption_spec_key_name, + sync=sync, + ) + + @base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False) + def _deploy( + self, + endpoint: Optional["Endpoint"] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> Endpoint: + """ + Deploys model to endpoint. Endpoint will be created if unspecified. + + Args: + endpoint ("Endpoint"): + Optional. Endpoint to deploy model to. If not specified, endpoint + display name will be model display name+'_endpoint'. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the smaller value of min_replica_count or 1 will + be used. + accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + endpoint ("Endpoint"): + Endpoint with the deployed model. + """ + + if endpoint is None: + display_name = self.display_name[:118] + "_endpoint" + endpoint = Endpoint.create( + display_name=display_name, + project=self.project, + location=self.location, + credentials=self.credentials, + encryption_spec_key_name=encryption_spec_key_name, + ) + + _LOGGER.log_action_start_against_resource("Deploying model to", "", endpoint) + + Endpoint._deploy_call( + endpoint.api_client, + endpoint.resource_name, + self.resource_name, + endpoint._gca_resource.traffic_split, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + ) + + _LOGGER.log_action_completed_against_resource("model", "deployed", endpoint) + + endpoint._sync_gca_resource() + + return endpoint + + def batch_predict( + self, + job_display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bigquery_source: Optional[str] = None, + instances_format: str = "jsonl", + gcs_destination_prefix: Optional[str] = None, + bigquery_destination_prefix: Optional[str] = None, + predictions_format: str = "jsonl", + model_parameters: Optional[Dict] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + starting_replica_count: Optional[int] = None, + max_replica_count: Optional[int] = None, + generate_explanation: Optional[bool] = False, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + labels: Optional[dict] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> jobs.BatchPredictionJob: + """Creates a batch prediction job using this Model and outputs prediction + results to the provided destination prefix in the specified + `predictions_format`. One source and one destination prefix are required. + + Example usage: + + my_model.batch_predict( + job_display_name="prediction-123", + gcs_source="gs://example-bucket/instances.csv", + instances_format="csv", + bigquery_destination_prefix="projectId.bqDatasetId.bqTableId" + ) + + Args: + job_display_name (str): + Required. The user-defined name of the BatchPredictionJob. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source: Optional[Sequence[str]] = None + Google Cloud Storage URI(-s) to your instances to run + batch prediction on. They must match `instances_format`. + May contain wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + bigquery_source: Optional[str] = None + BigQuery URI to a table, up to 2000 characters long. For example: + `projectId.bqDatasetId.bqTableId` + instances_format: str = "jsonl" + Required. The format in which instances are given, must be one + of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip", + or "file-list". Default is "jsonl" when using `gcs_source`. If a + `bigquery_source` is provided, this is overriden to "bigquery". + gcs_destination_prefix: Optional[str] = None + The Google Cloud Storage location of the directory where the + output is to be written to. In the given directory a new + directory is created. Its name is + ``prediction--``, where + timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. + Inside of it files ``predictions_0001.``, + ``predictions_0002.``, ..., + ``predictions_N.`` are created where + ```` depends on chosen ``predictions_format``, + and N may equal 0001 and depends on the total number of + successfully predicted instances. If the Model has both + ``instance`` and ``prediction`` schemata defined then each such + file contains predictions as per the ``predictions_format``. + If prediction for any instance failed (partially or + completely), then an additional ``errors_0001.``, + ``errors_0002.``,..., ``errors_N.`` + files are created (N depends on total number of failed + predictions). These files contain the failed instances, as + per their schema, followed by an additional ``error`` field + which as value has ```google.rpc.Status`` `__ + containing only ``code`` and ``message`` fields. + bigquery_destination_prefix: Optional[str] = None + The BigQuery project location where the output is to be + written to. In the given project a new dataset is created + with name + ``prediction__`` where + is made BigQuery-dataset-name compatible (for example, most + special characters become underscores), and timestamp is in + YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the + dataset two tables will be created, ``predictions``, and + ``errors``. If the Model has both ``instance`` and ``prediction`` + schemata defined then the tables have columns as follows: + The ``predictions`` table contains instances for which the + prediction succeeded, it has columns as per a concatenation + of the Model's instance and prediction schemata. The + ``errors`` table contains rows for which the prediction has + failed, it has instance columns, as per the instance schema, + followed by a single "errors" column, which as values has + ```google.rpc.Status`` `__ represented as a STRUCT, + and containing only ``code`` and ``message``. + predictions_format: str = "jsonl" + Required. The format in which AI Platform gives the + predictions, must be one of "jsonl", "csv", or "bigquery". + Default is "jsonl" when using `gcs_destination_prefix`. If a + `bigquery_destination_prefix` is provided, this is overriden to + "bigquery". + model_parameters: Optional[Dict] = None + Optional. The parameters that govern the predictions. The schema of + the parameters may be specified via the Model's `parameters_schema_uri`. + machine_type: Optional[str] = None + Optional. The type of machine for running batch prediction on + dedicated resources. Not specifying machine type will result in + batch prediction job being run with automatic resources. + accelerator_type: Optional[str] = None + Optional. The type of accelerator(s) that may be attached + to the machine as per `accelerator_count`. Only used if + `machine_type` is set. + accelerator_count: Optional[int] = None + Optional. The number of accelerators to attach to the + `machine_type`. Only used if `machine_type` is set. + starting_replica_count: Optional[int] = None + The number of machine replicas used at the start of the batch + operation. If not set, AI Platform decides starting number, not + greater than `max_replica_count`. Only used if `machine_type` is + set. + max_replica_count: Optional[int] = None + The maximum number of machine replicas the batch operation may + be scaled to. Only used if `machine_type` is set. + Default is 10. + generate_explanation (bool): + Optional. Generate explanation along with the batch prediction + results. This will cause the batch prediction output to include + explanations based on the `prediction_format`: + - `bigquery`: output includes a column named `explanation`. The value + is a struct that conforms to the [aiplatform.gapic.Explanation] object. + - `jsonl`: The JSON objects on each line include an additional entry + keyed `explanation`. The value of the entry is a JSON object that + conforms to the [aiplatform.gapic.Explanation] object. + - `csv`: Generating explanations for CSV format is not supported. + explanation_metadata (explain.ExplanationMetadata): + Optional. Explanation metadata configuration for this BatchPredictionJob. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_metadata`. + All fields of `explanation_metadata` are optional in the request. If + a field of the `explanation_metadata` object is not populated, the + corresponding field of the `Model.explanation_metadata` object is inherited. + For more details, see `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_parameters`. + All fields of `explanation_parameters` are optional in the request. If + a field of the `explanation_parameters` object is not populated, the + corresponding field of the `Model.explanation_parameters` object is inherited. + For more details, see `Ref docs ` + labels: Optional[dict] = None + Optional. The labels with user-defined metadata to organize your + BatchPredictionJobs. Label keys and values can be no longer than + 64 characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. See https://goo.gl/xmQnxf + for more information and examples of labels. + credentials: Optional[auth_credentials.Credentials] = None + Optional. Custom credentials to use to create this batch prediction + job. Overrides credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Returns: + (jobs.BatchPredictionJob): + Instantiated representation of the created batch prediction job. + + """ + self.wait() + + return jobs.BatchPredictionJob.create( + job_display_name=job_display_name, + model_name=self.resource_name, + instances_format=instances_format, + predictions_format=predictions_format, + gcs_source=gcs_source, + bigquery_source=bigquery_source, + gcs_destination_prefix=gcs_destination_prefix, + bigquery_destination_prefix=bigquery_destination_prefix, + model_parameters=model_parameters, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + starting_replica_count=starting_replica_count, + max_replica_count=max_replica_count, + generate_explanation=generate_explanation, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + labels=labels, + project=self.project, + location=self.location, + credentials=credentials or self.credentials, + encryption_spec_key_name=encryption_spec_key_name, + sync=sync, + ) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["models.Model"]: + """List all Model resource instances. + + Example Usage: + + aiplatform.Model.list( + filter='labels.my_label="my_label_value" AND display_name="my_model"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[models.Model] - A list of Model resource objects + """ + + return cls._list( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py new file mode 100644 index 0000000000..04d2f026a1 --- /dev/null +++ b/google/cloud/aiplatform/schema.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Namespaced AI Platform Schemas.""" + + +class training_job: + class definition: + custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml" + automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml" + automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml" + automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml" + automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml" + automl_text_extraction = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml" + automl_text_sentiment = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_sentiment_1.0.0.yaml" + automl_video_action_recognition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_action_recognition_1.0.0.yaml" + automl_video_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_classification_1.0.0.yaml" + automl_video_object_tracking = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml" + + +class dataset: + class metadata: + tabular = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml" + ) + image = "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml" + text = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml" + video = "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml" + + class ioformat: + class image: + multi_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_classification_multi_label_io_format_1.0.0.yaml" + single_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_classification_single_label_io_format_1.0.0.yaml" + bounding_box = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml" + image_segmentation = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_segmentation_io_format_1.0.0.yaml" + + class text: + multi_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_classification_multi_label_io_format_1.0.0.yaml" + single_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_classification_single_label_io_format_1.0.0.yaml" + extraction = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml" + sentiment = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_sentiment_io_format_1.0.0.yaml" + + class video: + action_recognition = "gs://google-cloud-aiplatform/schema/dataset/ioformat/video_action_recognition_io_format_1.0.0.yaml" + classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/video_classification_io_format_1.0.0.yaml" + object_tracking = "gs://google-cloud-aiplatform/schema/dataset/ioformat/video_object_tracking_io_format_1.0.0.yaml" + + class annotation: + class image: + classification = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml" + bounding_box = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_bounding_box_1.0.0.yaml" + segmentation = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_segmentation_1.0.0.yaml" + + class text: + classification = "gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml" + extraction = "gs://google-cloud-aiplatform/schema/dataset/annotation/text_extraction_1.0.0.yaml" + sentiment = "gs://google-cloud-aiplatform/schema/dataset/annotation/text_sentiment_1.0.0.yaml" + + class video: + classification = "gs://google-cloud-aiplatform/schema/dataset/annotation/video_classification_1.0.0.yaml" + object_tracking = "gs://google-cloud-aiplatform/schema/dataset/annotation/video_object_tracking_1.0.0.yaml" + action_recognition = "gs://google-cloud-aiplatform/schema/dataset/annotation/video_action_recognition_1.0.0.yaml" diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py new file mode 100644 index 0000000000..220a34637e --- /dev/null +++ b/google/cloud/aiplatform/training_jobs.py @@ -0,0 +1,4362 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import functools +import logging +import pathlib +import shutil +import subprocess +import sys +import tempfile +import time +from typing import Callable, Dict, List, Optional, NamedTuple, Sequence, Tuple, Union + +import abc + +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import base +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.types import ( + accelerator_type as gca_accelerator_type, + env_var as gca_env_var, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +from google.cloud.aiplatform.v1.schema.trainingjob import ( + definition_v1 as training_job_inputs, +) + +from google.cloud import storage +from google.rpc import code_pb2 + +import proto + + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) +_LOGGER = base.Logger(__name__) + +_PIPELINE_COMPLETE_STATES = set( + [ + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + gca_pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED, + gca_pipeline_state.PipelineState.PIPELINE_STATE_PAUSED, + ] +) + + +class _TrainingJob(base.AiPlatformResourceNounWithFutureManager): + + client_class = utils.PipelineClientWithOverride + _is_client_prediction_client = False + _resource_noun = "trainingPipelines" + _getter_method = "get_training_pipeline" + _list_method = "list_training_pipelines" + _delete_method = "delete_training_pipeline" + + def __init__( + self, + display_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + project (str): + Optional project to retrieve model from. If not set, project set in + aiplatform.init will be used. + location (str): + Optional location to retrieve model from. If not set, location set in + aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional credentials to use to retrieve the model. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + """ + utils.validate_display_name(display_name) + + super().__init__(project=project, location=location, credentials=credentials) + self._display_name = display_name + self._project = project + self._training_encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=training_encryption_spec_key_name + ) + self._model_encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=model_encryption_spec_key_name + ) + self._gca_resource = None + + @property + @classmethod + @abc.abstractmethod + def _supported_training_schemas(cls) -> Tuple[str]: + """List of supported schemas for this training job""" + + pass + + @classmethod + def get( + cls, + resource_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "_TrainingJob": + """Get Training Job for the given resource_name. + + Args: + resource_name (str): + Required. A fully-qualified resource name or ID. + project (str): + Optional project to retrieve dataset from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve dataset from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + Raises: + ValueError: If the retrieved training job's training task definition + doesn't match the custom training task definition. + + Returns: + An AI Platform Training Job + """ + + # Create job with dummy parameters + # These parameters won't be used as user can not run the job again. + # If they try, an exception will be raised. + self = cls._empty_constructor( + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) + + self._gca_resource = self._get_gca_resource(resource_name=resource_name) + + if ( + self._gca_resource.training_task_definition + not in cls._supported_training_schemas + ): + raise ValueError( + f"The retrieved job's training task definition " + f"is {self._gca_resource.training_task_definition}, " + f"which is not compatible with {cls.__name__}." + ) + + return self + + @property + @abc.abstractmethod + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + + pass + + @abc.abstractmethod + def run(self) -> Optional[models.Model]: + """Runs the training job. Should call _run_job internally""" + pass + + @staticmethod + def _create_input_data_config( + dataset: Optional[datasets._Dataset] = None, + annotation_schema_uri: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + gcs_destination_uri_prefix: Optional[str] = None, + bigquery_destination: Optional[str] = None, + ) -> Optional[gca_training_pipeline.InputDataConfig]: + """Constructs a input data config to pass to the training pipeline. + + Args: + dataset (datasets._Dataset): + The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + gcs_destination_uri_prefix (str): + Optional. The Google Cloud Storage location. + + The AI Platform environment variables representing Google + Cloud Storage data URIs will always be represented in the + Google Cloud Storage wildcard format to support sharded + data. + + - AIP_DATA_FORMAT = "jsonl". + - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" + - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" + - AIP_TEST_DATA_URI = "gcs_destination/test-*". + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + """ + + input_data_config = None + if dataset: + # Create fraction split spec + fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=training_fraction_split, + validation_fraction=validation_fraction_split, + test_fraction=test_fraction_split, + ) + + # Create predefined split spec + predefined_split = None + if predefined_split_column_name: + if ( + dataset._gca_resource.metadata_schema_uri + != schema.dataset.metadata.tabular + ): + raise ValueError( + "A pre-defined split may only be used with a tabular Dataset" + ) + + predefined_split = gca_training_pipeline.PredefinedSplit( + key=predefined_split_column_name + ) + + # Create GCS destination + gcs_destination = None + if gcs_destination_uri_prefix: + gcs_destination = gca_io.GcsDestination( + output_uri_prefix=gcs_destination_uri_prefix + ) + + # TODO(b/177416223) validate managed BQ dataset is passed in + bigquery_destination_proto = None + if bigquery_destination: + bigquery_destination_proto = gca_io.BigQueryDestination( + output_uri=bigquery_destination + ) + + # create input data config + input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=fraction_split, + predefined_split=predefined_split, + dataset_id=dataset.name, + annotation_schema_uri=annotation_schema_uri, + gcs_destination=gcs_destination, + bigquery_destination=bigquery_destination_proto, + ) + + return input_data_config + + def _run_job( + self, + training_task_definition: str, + training_task_inputs: Union[dict, proto.Message], + dataset: Optional[datasets._Dataset], + training_fraction_split: float, + validation_fraction_split: float, + test_fraction_split: float, + annotation_schema_uri: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + model: Optional[gca_model.Model] = None, + gcs_destination_uri_prefix: Optional[str] = None, + bigquery_destination: Optional[str] = None, + ) -> Optional[models.Model]: + """Runs the training job. + + Args: + training_task_definition (str): + Required. A Google Cloud Storage path to the + YAML file that defines the training task which + is responsible for producing the model artifact, + and may also include additional auxiliary work. + The definition files that can be used here are + found in gs://google-cloud- + aiplatform/schema/trainingjob/definition/. Note: + The URI given on output will be immutable and + probably different, including the URI scheme, + than the one given on input. The output URI will + point to a location where the user only has a + read access. + training_task_inputs (Union[dict, proto.Message]): + Required. The training task's input that corresponds to the training_task_definition parameter. + dataset (datasets._Dataset): + The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + model (~.model.Model): + Optional. Describes the Model that may be uploaded (via + [ModelService.UploadMode][]) by this TrainingPipeline. The + TrainingPipeline's + ``training_task_definition`` + should make clear whether this Model description should be + populated, and if there are any special requirements + regarding how it should be filled. If nothing is mentioned + in the + ``training_task_definition``, + then it should be assumed that this field should not be + filled and the training task either uploads the Model + without a need of this information, or that training task + does not support uploading a Model as part of the pipeline. + When the Pipeline's state becomes + ``PIPELINE_STATE_SUCCEEDED`` and the trained Model had been + uploaded into AI Platform, then the model_to_upload's + resource ``name`` + is populated. The Model is always uploaded into the Project + and Location in which this pipeline is. + gcs_destination_uri_prefix (str): + Optional. The Google Cloud Storage location. + + The AI Platform environment variables representing Google + Cloud Storage data URIs will always be represented in the + Google Cloud Storage wildcard format to support sharded + data. + + - AIP_DATA_FORMAT = "jsonl". + - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" + - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" + - AIP_TEST_DATA_URI = "gcs_destination/test-*". + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + """ + + input_data_config = self._create_input_data_config( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + gcs_destination_uri_prefix=gcs_destination_uri_prefix, + bigquery_destination=bigquery_destination, + ) + + # create training pipeline + training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=self._display_name, + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs, + model_to_upload=model, + input_data_config=input_data_config, + encryption_spec=self._training_encryption_spec, + ) + + training_pipeline = self.api_client.create_training_pipeline( + parent=initializer.global_config.common_location_path( + self.project, self.location + ), + training_pipeline=training_pipeline, + ) + + self._gca_resource = training_pipeline + + _LOGGER.info("View Training:\n%s" % self._dashboard_uri()) + + model = self._get_model() + + if model is None: + _LOGGER.warning( + "Training did not produce a Managed Model returning None. " + + self._model_upload_fail_string + ) + + return model + + def _is_waiting_to_run(self) -> bool: + """Returns True if the Job is pending on upstream tasks False otherwise.""" + self._raise_future_exception() + if self._latest_future: + _LOGGER.info( + "Training Job is waiting for upstream SDK tasks to complete before" + " launching." + ) + return True + return False + + @property + def state(self) -> Optional[gca_pipeline_state.PipelineState]: + """Current training state.""" + + if self._assert_has_run(): + return + + self._sync_gca_resource() + return self._gca_resource.state + + def get_model(self, sync=True) -> models.Model: + """AI Platform Model produced by this training, if one was produced. + + Args: + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: AI Platform Model produced by this training + + Raises: + RuntimeError if training failed or if a model was not produced by this training. + """ + + self._assert_has_run() + if not self._gca_resource.model_to_upload: + raise RuntimeError(self._model_upload_fail_string) + + return self._force_get_model(sync=sync) + + @base.optional_sync() + def _force_get_model(self, sync: bool = True) -> models.Model: + """AI Platform Model produced by this training, if one was produced. + + Args: + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: AI Platform Model produced by this training + + Raises: + RuntimeError if training failed or if a model was not produced by this training. + """ + model = self._get_model() + + if model is None: + raise RuntimeError(self._model_upload_fail_string) + + return model + + def _get_model(self) -> Optional[models.Model]: + """Helper method to get and instantiate the Model to Upload. + + Returns: + model: AI Platform Model if training succeeded and produced an AI Platform + Model. None otherwise. + + Raises: + RuntimeError if Training failed. + """ + self._block_until_complete() + + if self.has_failed: + raise RuntimeError( + f"Training Pipeline {self.resource_name} failed. No model available." + ) + + if not self._gca_resource.model_to_upload: + return None + + if self._gca_resource.model_to_upload.name: + fields = utils.extract_fields_from_resource_name( + self._gca_resource.model_to_upload.name + ) + + return models.Model( + fields.id, project=fields.project, location=fields.location, + ) + + def _block_until_complete(self): + """Helper method to block and check on job until complete.""" + + # Used these numbers so failures surface fast + wait = 5 # start at five seconds + log_wait = 5 + max_wait = 60 * 5 # 5 minute wait + multiplier = 2 # scale wait by 2 every iteration + + previous_time = time.time() + while self.state not in _PIPELINE_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + time.sleep(wait) + + self._raise_failure() + + _LOGGER.log_action_completed_against_resource("run", "completed", self) + + if self._gca_resource.model_to_upload and not self.has_failed: + _LOGGER.info( + "Model available at %s" % self._gca_resource.model_to_upload.name + ) + + def _raise_failure(self): + """Helper method to raise failure if TrainingPipeline fails. + + Raises: + RuntimeError: If training failed.""" + + if self._gca_resource.error.code != code_pb2.OK: + raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error) + + @property + def has_failed(self) -> bool: + """Returns True if training has failed. False otherwise.""" + self._assert_has_run() + return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED + + def _dashboard_uri(self) -> str: + """Helper method to compose the dashboard uri where training can be viewed.""" + fields = utils.extract_fields_from_resource_name(self.resource_name) + url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}" + return url + + def _sync_gca_resource(self): + """Helper method to sync the local gca_source against the service.""" + self._gca_resource = self.api_client.get_training_pipeline( + name=self.resource_name + ) + + @property + def _has_run(self) -> bool: + """Helper property to check if this training job has been run.""" + return self._gca_resource is not None + + def _assert_has_run(self) -> bool: + """Helper method to assert that this training has run.""" + if not self._has_run: + if self._is_waiting_to_run(): + return True + raise RuntimeError( + "TrainingPipeline has not been launched. You must run this" + " TrainingPipeline using TrainingPipeline.run. " + ) + return False + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["base.AiPlatformResourceNoune"]: + """List all instances of this TrainingJob resource. + + Example Usage: + + aiplatform.CustomTrainingJob.list( + filter='display_name="experiment_a27"', + order_by='create_time desc' + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of TrainingJob resource objects + """ + + training_job_subclass_filter = ( + lambda gapic_obj: gapic_obj.training_task_definition + in cls._supported_training_schemas + ) + + return cls._list_with_local_order( + cls_filter=training_job_subclass_filter, + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + def cancel(self) -> None: + """Starts asynchronous cancellation on the TrainingJob. The server + makes a best effort to cancel the job, but success is not guaranteed. + On successful cancellation, the TrainingJob is not deleted; instead it + becomes a job with state set to `CANCELLED`. + + Raises: + RuntimeError if this TrainingJob has not started running. + """ + if not self._has_run: + raise RuntimeError( + "This TrainingJob has not been launched, use the `run()` method " + "to start. `cancel()` can only be called on a job that is running." + ) + self.api_client.cancel_training_pipeline(name=self.resource_name) + + +def _timestamped_gcs_dir(root_gcs_path: str, dir_name_prefix: str) -> str: + """Composes a timestamped GCS directory. + + Args: + root_gcs_path: GCS path to put the timestamped directory. + dir_name_prefix: Prefix to add the timestamped directory. + Returns: + Timestamped gcs directory path in root_gcs_path. + """ + timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds") + dir_name = "-".join([dir_name_prefix, timestamp]) + if root_gcs_path.endswith("/"): + root_gcs_path = root_gcs_path[:-1] + gcs_path = "/".join([root_gcs_path, dir_name]) + if not gcs_path.startswith("gs://"): + return "gs://" + gcs_path + return gcs_path + + +def _timestamped_copy_to_gcs( + local_file_path: str, + gcs_dir: str, + project: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, +) -> str: + """Copies a local file to a GCS path. + + The file copied to GCS is the name of the local file prepended with an + "aiplatform-{timestamp}-" string. + + Args: + local_file_path (str): Required. Local file to copy to GCS. + gcs_dir (str): + Required. The GCS directory to copy to. + project (str): + Project that contains the staging bucket. Default will be used if not + provided. Model Builder callers should pass this in. + credentials (auth_credentials.Credentials): + Custom credentials to use with bucket. Model Builder callers should pass + this in. + Returns: + gcs_path (str): The path of the copied file in gcs. + """ + + gcs_bucket, gcs_blob_prefix = utils.extract_bucket_and_prefix_from_gcs_path(gcs_dir) + + local_file_name = pathlib.Path(local_file_path).name + timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds") + blob_path = "-".join(["aiplatform", timestamp, local_file_name]) + + if gcs_blob_prefix: + blob_path = "/".join([gcs_blob_prefix, blob_path]) + + # TODO(b/171202993) add user agent + client = storage.Client(project=project, credentials=credentials) + bucket = client.bucket(gcs_bucket) + blob = bucket.blob(blob_path) + blob.upload_from_filename(local_file_path) + + gcs_path = "".join(["gs://", "/".join([blob.bucket.name, blob.name])]) + return gcs_path + + +def _get_python_executable() -> str: + """Returns Python executable. + + Raises: + EnvironmentError if Python executable is not found. + Returns: + Python executable to use for setuptools packaging. + """ + + python_executable = sys.executable + + if not python_executable: + raise EnvironmentError("Cannot find Python executable for packaging.") + return python_executable + + +class _TrainingScriptPythonPackager: + """Converts a Python script into Python package suitable for aiplatform training. + + Copies the script to specified location. + + Class Attributes: + _TRAINER_FOLDER: Constant folder name to build package. + _ROOT_MODULE: Constant root name of module. + _TEST_MODULE_NAME: Constant name of module that will store script. + _SETUP_PY_VERSION: Constant version of this created python package. + _SETUP_PY_TEMPLATE: Constant template used to generate setup.py file. + _SETUP_PY_SOURCE_DISTRIBUTION_CMD: + Constant command to generate the source distribution package. + + Attributes: + script_path: local path of script to package + requirements: list of Python dependencies to add to package + + Usage: + + packager = TrainingScriptPythonPackager('my_script.py', ['pandas', 'pytorch']) + gcs_path = packager.package_and_copy_to_gcs( + gcs_staging_dir='my-bucket', + project='my-prject') + module_name = packager.module_name + + The package after installed can be executed as: + python -m aiplatform_custom_trainer_script.task + + """ + + _TRAINER_FOLDER = "trainer" + _ROOT_MODULE = "aiplatform_custom_trainer_script" + _TASK_MODULE_NAME = "task" + _SETUP_PY_VERSION = "0.1" + + _SETUP_PY_TEMPLATE = """from setuptools import find_packages +from setuptools import setup + +setup( + name='{name}', + version='{version}', + packages=find_packages(), + install_requires=({requirements}), + include_package_data=True, + description='My training application.' +)""" + + _SETUP_PY_SOURCE_DISTRIBUTION_CMD = "setup.py sdist --formats=gztar" + + # Module name that can be executed during training. ie. python -m + module_name = f"{_ROOT_MODULE}.{_TASK_MODULE_NAME}" + + def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = None): + """Initializes packager. + + Args: + script_path (str): Required. Local path to script. + requirements (Sequence[str]): + List of python packages dependencies of script. + """ + + self.script_path = script_path + self.requirements = requirements or [] + + def make_package(self, package_directory: str) -> str: + """Converts script into a Python package suitable for python module execution. + + Args: + package_directory (str): Directory to build package in. + Returns: + source_distribution_path (str): Path to built package. + Raises: + RunTimeError if package creation fails. + """ + # The root folder to builder the package in + package_path = pathlib.Path(package_directory) + + # Root directory of the package + trainer_root_path = package_path / self._TRAINER_FOLDER + + # The root module of the python package + trainer_path = trainer_root_path / self._ROOT_MODULE + + # __init__.py path in root module + init_path = trainer_path / "__init__.py" + + # The module that will contain the script + script_out_path = trainer_path / f"{self._TASK_MODULE_NAME}.py" + + # The path to setup.py in the package. + setup_py_path = trainer_root_path / "setup.py" + + # The path to the generated source distribution. + source_distribution_path = ( + trainer_root_path + / "dist" + / f"{self._ROOT_MODULE}-{self._SETUP_PY_VERSION}.tar.gz" + ) + + trainer_root_path.mkdir() + trainer_path.mkdir() + + # Make empty __init__.py + with init_path.open("w"): + pass + + # Format the setup.py file. + setup_py_output = self._SETUP_PY_TEMPLATE.format( + name=self._ROOT_MODULE, + requirements=",".join(f'"{r}"' for r in self.requirements), + version=self._SETUP_PY_VERSION, + ) + + # Write setup.py + with setup_py_path.open("w") as fp: + fp.write(setup_py_output) + + # Copy script as module of python package. + shutil.copy(self.script_path, script_out_path) + + # Run setup.py to create the source distribution. + setup_cmd = [ + _get_python_executable() + ] + self._SETUP_PY_SOURCE_DISTRIBUTION_CMD.split() + + p = subprocess.Popen( + args=setup_cmd, + cwd=trainer_root_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + output, error = p.communicate() + + # Raise informative error if packaging fails. + if p.returncode != 0: + raise RuntimeError( + "Packaging of training script failed with code %d\n%s \n%s" + % (p.returncode, output.decode(), error.decode()) + ) + + return str(source_distribution_path) + + def package_and_copy(self, copy_method: Callable[[str], str]) -> str: + """Packages the script and executes copy with given copy_method. + + Args: + copy_method Callable[[str], str] + Takes a string path, copies to a desired location, and returns the + output path location. + Returns: + output_path str: Location of copied package. + """ + + with tempfile.TemporaryDirectory() as tmpdirname: + source_distribution_path = self.make_package(tmpdirname) + output_location = copy_method(source_distribution_path) + _LOGGER.info("Training script copied to:\n%s." % output_location) + return output_location + + def package_and_copy_to_gcs( + self, + gcs_staging_dir: str, + project: str = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> str: + """Packages script in Python package and copies package to GCS bucket. + + Args + gcs_staging_dir (str): Required. GCS Staging directory. + project (str): Required. Project where GCS Staging bucket is located. + credentials (auth_credentials.Credentials): + Optional credentials used with GCS client. + Returns: + GCS location of Python package. + """ + + copy_method = functools.partial( + _timestamped_copy_to_gcs, + gcs_dir=gcs_staging_dir, + project=project, + credentials=credentials, + ) + return self.package_and_copy(copy_method=copy_method) + + +class _MachineSpec(NamedTuple): + """Specification container for Machine specs used for distributed training. + + Usage: + + spec = _MachineSpec( + replica_count=10, + machine_type='n1-standard-4', + accelerator_count=2, + accelerator_type='NVIDIA_TESLA_K80') + + Note that container and python package specs are not stored with this spec. + """ + + replica_count: int = 0 + machine_type: str = "n1-standard-4" + accelerator_count: int = 0 + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED" + + def _get_accelerator_type(self) -> Optional[str]: + """Validates accelerator_type and returns the name of the accelerator. + + Returns: + None if no accelerator or valid accelerator name. + + Raise: + ValueError if accelerator type is invalid. + """ + + # Raises ValueError if invalid accelerator_type + utils.validate_accelerator_type(self.accelerator_type) + + accelerator_enum = getattr( + gca_accelerator_type.AcceleratorType, self.accelerator_type + ) + + if ( + accelerator_enum + != gca_accelerator_type.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED + ): + return self.accelerator_type + + @property + def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]: + """Return specification as a Dict.""" + spec = { + "machineSpec": {"machineType": self.machine_type}, + "replicaCount": self.replica_count, + } + accelerator_type = self._get_accelerator_type() + if accelerator_type and self.accelerator_count: + spec["machineSpec"]["acceleratorType"] = accelerator_type + spec["machineSpec"]["acceleratorCount"] = self.accelerator_count + + return spec + + @property + def is_empty(self) -> bool: + """Returns True is replica_count > 0 False otherwise.""" + return self.replica_count <= 0 + + +class _DistributedTrainingSpec(NamedTuple): + """Configuration for distributed training worker pool specs. + + AI Platform Training expects configuration in this order: + [ + chief spec, # can only have one replica + worker spec, + parameter server spec, + evaluator spec + ] + + Usage: + + dist_training_spec = _DistributedTrainingSpec( + chief_spec = _MachineSpec( + replica_count=1, + machine_type='n1-standard-4', + accelerator_count=2, + accelerator_type='NVIDIA_TESLA_K80' + ), + worker_spec = _MachineSpec( + replica_count=10, + machine_type='n1-standard-4', + accelerator_count=2, + accelerator_type='NVIDIA_TESLA_K80' + ) + ) + + """ + + chief_spec: _MachineSpec = _MachineSpec() + worker_spec: _MachineSpec = _MachineSpec() + parameter_server_spec: _MachineSpec = _MachineSpec() + evaluator_spec: _MachineSpec = _MachineSpec() + + @property + def pool_specs( + self, + ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]: + """Return each pools spec in correct order for AI Platform as a list of dicts. + + Also removes specs if they are empty but leaves specs in if there unusual + specifications to not break the ordering in AI Platform Training. + ie. 0 chief replica, 10 worker replica, 3 ps replica + + Returns: + Order list of worker pool specs suitable for AI Platform Training. + """ + if self.chief_spec.replica_count > 1: + raise ValueError("Chief spec replica count cannot be greater than 1.") + + spec_order = [ + self.chief_spec, + self.worker_spec, + self.parameter_server_spec, + self.evaluator_spec, + ] + specs = [s.spec_dict for s in spec_order] + for i in reversed(range(len(spec_order))): + if spec_order[i].is_empty: + specs.pop() + else: + break + return specs + + @classmethod + def chief_worker_pool( + cls, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_count: int = 0, + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + ) -> "_DistributedTrainingSpec": + """Parameterizes Config to support only chief with worker replicas. + + For replica is assigned to chief and the remainder to workers. All spec have the + same machine type, accelerator count, and accelerator type. + + Args: + replica_count (int): + The number of worker replicas. Assigns 1 chief replica and + replica_count - 1 worker replicas. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + + Returns: + _DistributedTrainingSpec representing one chief and n workers all of same + type. If replica_count <= 0 then an empty spec is returned. + """ + if replica_count <= 0: + return cls() + + chief_spec = _MachineSpec( + replica_count=1, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + worker_spec = _MachineSpec( + replica_count=replica_count - 1, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + return cls(chief_spec=chief_spec, worker_spec=worker_spec) + + +class _CustomTrainingJob(_TrainingJob): + """ABC for Custom Training Pipelines.. + """ + + _supported_training_schemas = (schema.training_job.definition.custom_task,) + + def __init__( + self, + display_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """ + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + container_uri (str): + Required: Uri of the training container image in the GCR. + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. + model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + self._container_uri = container_uri + + model_predict_schemata = None + if any( + [ + model_instance_schema_uri, + model_parameters_schema_uri, + model_prediction_schema_uri, + ] + ): + model_predict_schemata = gca_model.PredictSchemata( + instance_schema_uri=model_instance_schema_uri, + parameters_schema_uri=model_parameters_schema_uri, + prediction_schema_uri=model_prediction_schema_uri, + ) + + # Create the container spec + env = None + ports = None + + if model_serving_container_environment_variables: + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in model_serving_container_environment_variables.items() + ] + + if model_serving_container_ports: + ports = [ + gca_model.Port(container_port=port) + for port in model_serving_container_ports + ] + + container_spec = gca_model.ModelContainerSpec( + image_uri=model_serving_container_image_uri, + command=model_serving_container_command, + args=model_serving_container_args, + env=env, + ports=ports, + predict_route=model_serving_container_predict_route, + health_route=model_serving_container_health_route, + ) + + # create model payload + self._managed_model = gca_model.Model( + description=model_description, + predict_schemata=model_predict_schemata, + container_spec=container_spec, + encryption_spec=self._model_encryption_spec, + ) + + self._staging_bucket = ( + staging_bucket or initializer.global_config.staging_bucket + ) + + if not self._staging_bucket: + raise RuntimeError( + "staging_bucket should be set in TrainingJob constructor or " + "set using aiplatform.init(staging_bucket='gs://my-bucket')" + ) + + def _prepare_and_validate_run( + self, + model_display_name: Optional[str] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + ) -> Tuple[_DistributedTrainingSpec, Optional[gca_model.Model]]: + """Create worker pool specs and managed model as well validating the run. + + Args: + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + Returns: + Worker pools specs and managed model for run. + + Raises: + RuntimeError if Training job has already been run or model_display_name was + provided but required arguments were not provided in constructor. + + """ + + if self._is_waiting_to_run(): + raise RuntimeError("Custom Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("Custom Training has already run.") + + # if args needed for model is incomplete + if model_display_name and not self._managed_model.container_spec.image_uri: + raise RuntimeError( + """model_display_name was provided but + model_serving_container_image_uri was not provided when this + custom pipeline was constructed. + """ + ) + + # validates args and will raise + worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool( + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ).pool_specs + + managed_model = self._managed_model + if model_display_name: + utils.validate_display_name(model_display_name) + managed_model.display_name = model_display_name + else: + managed_model = None + + return worker_pool_specs, managed_model + + def _prepare_training_task_inputs_and_output_dir( + self, + worker_pool_specs: _DistributedTrainingSpec, + base_output_dir: Optional[str] = None, + ) -> Tuple[Dict, str]: + """Prepares training task inputs and output directory for custom job. + + Args: + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + Returns: + Training task inputs and Output directory for custom job. + """ + + # default directory if not given + base_output_dir = base_output_dir or _timestamped_gcs_dir( + self._staging_bucket, "aiplatform-custom-training" + ) + + _LOGGER.info("Training Output directory:\n%s " % base_output_dir) + + training_task_inputs = { + "workerPoolSpecs": worker_pool_specs, + "baseOutputDirectory": {"output_uri_prefix": base_output_dir}, + } + + return training_task_inputs, base_output_dir + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"Training Pipeline {self.resource_name} is not configured to upload a " + "Model. Create the Training Pipeline with " + "model_serving_container_image_uri and model_display_name passed in. " + "Ensure that your training script saves to model to " + "os.environ['AIP_MODEL_DIR']." + ) + + +# TODO(b/172368325) add scheduling, custom_job.Scheduling +class CustomTrainingJob(_CustomTrainingJob): + """Class to launch a Custom Training Job in AI Platform using a script. + + Takes a training implementation as a python script and executes that script + in Cloud AI Platform Training. + """ + + def __init__( + self, + display_name: str, + script_path: str, + container_uri: str, + requirements: Optional[Sequence[str]] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Constructs a Custom Training Job from a Python script. + + job = aiplatform.CustomTrainingJob( + display_name='test-train', + script_path='test_script.py', + requirements=['pandas', 'numpy'], + container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest', + model_serving_container_image_uri='gcr.io/my-trainer/serving:1', + model_serving_container_predict_route='predict', + model_serving_container_health_route='metadata) + + Usage with Dataset: + + ds = aiplatform.TabularDataset( + 'projects/my-project/locations/us-central1/datasets/12345') + + job.run(ds, replica_count=1, model_display_name='my-trained-model') + + Usage without Dataset: + + job.run(replica_count=1, model_display_name='my-trained-model) + + + TODO(b/169782082) add documentation about traning utilities + To ensure your model gets saved in AI Platform, write your saved model to + os.environ["AIP_MODEL_DIR"] in your provided training script. + + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + script_path (str): Required. Local path to training script. + container_uri (str): + Required: Uri of the training container image in the GCR. + requirements (Sequence[str]): + List of python packages dependencies of script. + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. + model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + container_uri=container_uri, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_description=model_description, + staging_bucket=staging_bucket, + ) + + self._requirements = requirements + self._script_path = script_path + + # TODO(b/172365904) add filter split, training_pipeline.FilterSplit + # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit + def run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Runs the custom training job. + + Distributed Training Support: + If replica count = 1 then one chief replica will be provisioned. If + replica_count > 1 the remainder will be provisioned as a worker replica pool. + ie: replica_count = 10 will result in 1 chief and 9 workers + All replicas have same machine_type, accelerator_type, and accelerator_count + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform.If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. Custom training script should + retrieve datasets through passed in environment variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + worker_pool_specs, managed_model = self._prepare_and_validate_run( + model_display_name=model_display_name, + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + # make and copy package + python_packager = _TrainingScriptPythonPackager( + script_path=self._script_path, requirements=self._requirements + ) + + return self._run( + python_packager=python_packager, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + worker_pool_specs=worker_pool_specs, + managed_model=managed_model, + args=args, + base_output_dir=base_output_dir, + bigquery_destination=bigquery_destination, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + sync=sync, + ) + + @base.optional_sync(construct_object_on_arg="managed_model") + def _run( + self, + python_packager: _TrainingScriptPythonPackager, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], + annotation_schema_uri: Optional[str], + worker_pool_specs: _DistributedTrainingSpec, + managed_model: Optional[gca_model.Model] = None, + args: Optional[List[Union[str, float, int]]] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Packages local script and launches training_job. + + Args: + python_packager (_TrainingScriptPythonPackager): + Required. Python Packager pointing to training script locally. + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + managed_model (gca_model.Model): + Model proto if this script produces a Managed Model. + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + package_gcs_uri = python_packager.package_and_copy_to_gcs( + gcs_staging_dir=self._staging_bucket, + project=self.project, + credentials=self.credentials, + ) + + for spec in worker_pool_specs: + spec["pythonPackageSpec"] = { + "executorImageUri": self._container_uri, + "pythonModule": python_packager.module_name, + "packageUris": [package_gcs_uri], + } + + if args: + spec["pythonPackageSpec"]["args"] = args + + ( + training_task_inputs, + base_output_dir, + ) = self._prepare_training_task_inputs_and_output_dir( + worker_pool_specs, base_output_dir + ) + + model = self._run_job( + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=training_task_inputs, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=managed_model, + gcs_destination_uri_prefix=base_output_dir, + bigquery_destination=bigquery_destination, + ) + + return model + + +class CustomContainerTrainingJob(_CustomTrainingJob): + """Class to launch a Custom Training Job in AI Platform using a Container.""" + + def __init__( + self, + display_name: str, + container_uri: str, + command: Sequence[str] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Constructs a Custom Container Training Job. + + job = aiplatform.CustomTrainingJob( + display_name='test-train', + container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest', + command=['python3', 'run_script.py'] + model_serving_container_image_uri='gcr.io/my-trainer/serving:1', + model_serving_container_predict_route='predict', + model_serving_container_health_route='metadata) + + Usage with Dataset: + + ds = aiplatform.TabularDataset( + 'projects/my-project/locations/us-central1/datasets/12345') + + job.run(ds, replica_count=1, model_display_name='my-trained-model') + + Usage without Dataset: + + job.run(replica_count=1, model_display_name='my-trained-model) + + + TODO(b/169782082) add documentation about traning utilities + To ensure your model gets saved in AI Platform, write your saved model to + os.environ["AIP_MODEL_DIR"] in your provided training script. + + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + container_uri (str): + Required: Uri of the training container image in the GCR. + command (Sequence[str]): + The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. + model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + container_uri=container_uri, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_description=model_description, + staging_bucket=staging_bucket, + ) + + self._command = command + + # TODO(b/172365904) add filter split, training_pipeline.FilterSplit + # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit + def run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Runs the custom training job. + + Distributed Training Support: + If replica count = 1 then one chief replica will be provisioned. If + replica_count > 1 the remainder will be provisioned as a worker replica pool. + ie: replica_count = 10 will result in 1 chief and 9 workers + All replicas have same machine_type, accelerator_type, and accelerator_count + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. Custom training script should + retrieve datasets through passed in environment variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError if Training job has already been run, staging_bucket has not + been set, or model_display_name was provided but required arguments + were not provided in constructor. + """ + worker_pool_specs, managed_model = self._prepare_and_validate_run( + model_display_name=model_display_name, + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + return self._run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + worker_pool_specs=worker_pool_specs, + managed_model=managed_model, + args=args, + base_output_dir=base_output_dir, + bigquery_destination=bigquery_destination, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + sync=sync, + ) + + @base.optional_sync(construct_object_on_arg="managed_model") + def _run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], + annotation_schema_uri: Optional[str], + worker_pool_specs: _DistributedTrainingSpec, + managed_model: Optional[gca_model.Model] = None, + args: Optional[List[Union[str, float, int]]] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Packages local script and launches training_job. + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + managed_model (gca_model.Model): + Model proto if this script produces a Managed Model. + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + for spec in worker_pool_specs: + spec["containerSpec"] = {"imageUri": self._container_uri} + + if self._command: + spec["containerSpec"]["command"] = self._command + + if args: + spec["containerSpec"]["args"] = args + + ( + training_task_inputs, + base_output_dir, + ) = self._prepare_training_task_inputs_and_output_dir( + worker_pool_specs, base_output_dir + ) + + model = self._run_job( + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=training_task_inputs, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=managed_model, + gcs_destination_uri_prefix=base_output_dir, + bigquery_destination=bigquery_destination, + ) + + return model + + +class AutoMLTabularTrainingJob(_TrainingJob): + _supported_training_schemas = (schema.training_job.definition.automl_tabular,) + + def __init__( + self, + display_name: str, + optimization_prediction_type: str, + optimization_objective: Optional[str] = None, + column_transformations: Optional[Union[Dict, List[Dict]]] = None, + optimization_objective_recall_value: Optional[float] = None, + optimization_objective_precision_value: Optional[float] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a AutoML Tabular Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + optimization_prediction_type (str): + The type of prediction the Model is to produce. + "classification" - Predict one out of multiple target values is + picked for each row. + "regression" - Predict a value based on its relation to other values. + This type is available only to columns that contain + semantically numeric values, i.e. integers or floating + point number, even if stored as e.g. strings. + + optimization_objective (str): + Optional. Objective function the Model is to be optimized towards. The training + task creates a Model that maximizes/minimizes the value of the objective + function over the validation set. + + The supported optimization objectives depend on the prediction type, and + in the case of classification also the number of distinct values in the + target column (two distint values -> binary, 3 or more distinct values + -> multi class). + If the field is not set, the default objective function is used. + + Classification (binary): + "maximize-au-roc" (default) - Maximize the area under the receiver + operating characteristic (ROC) curve. + "minimize-log-loss" - Minimize log loss. + "maximize-au-prc" - Maximize the area under the precision-recall curve. + "maximize-precision-at-recall" - Maximize precision for a specified + recall value. + "maximize-recall-at-precision" - Maximize recall for a specified + precision value. + + Classification (multi class): + "minimize-log-loss" (default) - Minimize log loss. + + Regression: + "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE). + "minimize-mae" - Minimize mean-absolute error (MAE). + "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE). + column_transformations (Optional[Union[Dict, List[Dict]]]): + Optional. Transformations to apply to the input columns (i.e. columns other + than the targetColumn). Each transformation may produce multiple + result values from the column's value, and all are used for training. + When creating transformation for BigQuery Struct column, the column + should be flattened using "." as the delimiter. + If an input column has no transformations on it, such a column is + ignored by the training, except for the targetColumn, which should have + no transformations defined on. + optimization_objective_recall_value (float): + Optional. Required when maximize-precision-at-recall optimizationObjective was + picked, represents the recall value at which the optimization is done. + + The minimum value is 0 and the maximum is 1.0. + optimization_objective_precision_value (float): + Optional. Required when maximize-recall-at-precision optimizationObjective was + picked, represents the precision value at which the optimization is + done. + + The minimum value is 0 and the maximum is 1.0. + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + self._column_transformations = column_transformations + self._optimization_objective = optimization_objective + self._optimization_prediction_type = optimization_prediction_type + self._optimization_objective_recall_value = optimization_objective_recall_value + self._optimization_objective_precision_value = ( + optimization_objective_precision_value + ) + + def run( + self, + dataset: datasets.TabularDataset, + target_column: str, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TabularDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. The final + cost will be attempted to be close to the budget, though may end up + being (even) noticeably smaller - at the backend's discretion. This + especially may happen when further model training ceases to provide + any improvements. + If the budget is set to a value known to be insufficient to train a + Model for the given training set, the training won't be attempted and + will error. + The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + disable_early_stopping (bool): + Required. If true, the entire budget is used. This disables the early stopping + feature. By default, the early stopping feature is enabled, which means + that training might stop before the entire training budget has been + used, if further training does no longer brings significant improvement + to the model. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError if Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Tabular Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Tabular Training has already run.") + + return self._run( + dataset=dataset, + target_column=target_column, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + weight_column=weight_column, + budget_milli_node_hours=budget_milli_node_hours, + model_display_name=model_display_name, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.TabularDataset, + target_column: str, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TabularDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. The final + cost will be attempted to be close to the budget, though may end up + being (even) noticeably smaller - at the backend's discretion. This + especially may happen when further model training ceases to provide + any improvements. + If the budget is set to a value known to be insufficient to train a + Model for the given training set, the training won't be attempted and + will error. + The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + disable_early_stopping (bool): + Required. If true, the entire budget is used. This disables the early stopping + feature. By default, the early stopping feature is enabled, which means + that training might stop before the entire training budget has been + used, if further training does no longer brings significant improvement + to the model. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + training_task_definition = schema.training_job.definition.automl_tabular + + training_task_inputs_dict = { + # required inputs + "targetColumn": target_column, + "transformations": self._column_transformations, + "trainBudgetMilliNodeHours": budget_milli_node_hours, + # optional inputs + "weightColumnName": weight_column, + "disableEarlyStopping": disable_early_stopping, + "optimizationObjective": self._optimization_objective, + "predictionType": self._optimization_prediction_type, + "optimizationObjectiveRecallValue": self._optimization_objective_recall_value, + "optimizationObjectivePrecisionValue": self._optimization_objective_precision_value, + } + + if model_display_name is None: + model_display_name = self._display_name + + model = gca_model.Model( + display_name=model_display_name, + encryption_spec=self._model_encryption_spec, + ) + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=model, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"Training Pipeline {self.resource_name} is not configured to upload a " + "Model." + ) + + +class AutoMLImageTrainingJob(_TrainingJob): + _supported_training_schemas = ( + schema.training_job.definition.automl_image_classification, + schema.training_job.definition.automl_image_object_detection, + ) + + def __init__( + self, + display_name: str, + prediction_type: str = "classification", + multi_label: bool = False, + model_type: str = "CLOUD", + base_model: Optional[models.Model] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a AutoML Image Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + prediction_type (str): + The type of prediction the Model is to produce, one of: + "classification" - Predict one out of multiple target values is + picked for each row. + "object_detection" - Predict a value based on its relation to other values. + This type is available only to columns that contain + semantically numeric values, i.e. integers or floating + point number, even if stored as e.g. strings. + multi_label: bool = False + Required. Default is False. + If false, a single-label (multi-class) Model will be trained + (i.e. assuming that for each image just up to one annotation may be + applicable). If true, a multi-label Model will be trained (i.e. + assuming that for each image multiple annotations may be applicable). + + This is only applicable for the "classification" prediction_type and + will be ignored otherwise. + model_type: str = "CLOUD" + Required. One of the following: + "CLOUD" - Default for Image Classification. + A Model best tailored to be used within Google Cloud, and + which cannot be exported. + "CLOUD_HIGH_ACCURACY_1" - Default for Image Object Detection. + A model best tailored to be used within Google Cloud, and + which cannot be exported. Expected to have a higher latency, + but should also have a higher prediction quality than other + cloud models. + "CLOUD_LOW_LATENCY_1" - A model best tailored to be used within + Google Cloud, and which cannot be exported. Expected to have a + low latency, but may have lower prediction quality than other + cloud models. + "MOBILE_TF_LOW_LATENCY_1" - A model that, in addition to being + available within Google Cloud, can also be exported as TensorFlow + or Core ML model and used on a mobile or edge device afterwards. + Expected to have low latency, but may have lower prediction + quality than other mobile models. + "MOBILE_TF_VERSATILE_1" - A model that, in addition to being + available within Google Cloud, can also be exported as TensorFlow + or Core ML model and used on a mobile or edge device with afterwards. + "MOBILE_TF_HIGH_ACCURACY_1" - A model that, in addition to being + available within Google Cloud, can also be exported as TensorFlow + or Core ML model and used on a mobile or edge device afterwards. + Expected to have a higher latency, but should also have a higher + prediction quality than other mobile models. + base_model: Optional[models.Model] = None + Optional. Only permitted for Image Classification models. + If it is specified, the new model will be trained based on the `base` model. + Otherwise, the new model will be trained from scratch. The `base` model + must be in the same Project and Location as the new Model to train, + and have the same model_type. + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Raises: + ValueError: When an invalid prediction_type or model_type is provided. + """ + + valid_model_types = constants.AUTOML_IMAGE_PREDICTION_MODEL_TYPES.get( + prediction_type, None + ) + + if not valid_model_types: + raise ValueError( + f"'{prediction_type}' is not a supported prediction type for AutoML Image Training. " + f"Please choose one of: {tuple(constants.AUTOML_IMAGE_PREDICTION_MODEL_TYPES.keys())}." + ) + + # Override default model_type for object_detection + if model_type == "CLOUD" and prediction_type == "object_detection": + model_type = "CLOUD_HIGH_ACCURACY_1" + + if model_type not in valid_model_types: + raise ValueError( + f"'{model_type}' is not a supported model_type for prediction_type of '{prediction_type}'. " + f"Please choose one of: {tuple(valid_model_types)}" + ) + + if base_model and prediction_type != "classification": + raise ValueError( + "Training with a `base_model` is only supported in AutoML Image Classification. " + f"However '{prediction_type}' was provided as `prediction_type`." + ) + + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + self._model_type = model_type + self._prediction_type = prediction_type + self._multi_label = multi_label + self._base_model = base_model + + def run( + self, + dataset: datasets.ImageDataset, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the AutoML Image training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.ImageDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split: float = 0.8 + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split: float = 0.1 + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split: float = 0.1 + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + budget_milli_node_hours: int = 1000 + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. The final + cost will be attempted to be close to the budget, though may end up + being (even) noticeably smaller - at the backend's discretion. This + especially may happen when further model training ceases to provide + any improvements. + If the budget is set to a value known to be insufficient to train a + Model for the given training set, the training won't be attempted and + will error. + The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. The name + can be up to 128 characters long and can be consist of any UTF-8 + characters. If not provided upon creation, the job's display_name is used. + disable_early_stopping: bool = False + Required. If true, the entire budget is used. This disables the early stopping + feature. By default, the early stopping feature is enabled, which means + that training might stop before the entire training budget has been + used, if further training does no longer brings significant improvement + to the model. + sync: bool = True + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError: If Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Image Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Image Training has already run.") + + return self._run( + dataset=dataset, + base_model=self._base_model, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + model_display_name=model_display_name, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.ImageDataset, + base_model: Optional[models.Model] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.ImageDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + base_model: Optional[models.Model] = None + Optional. Only permitted for Image Classification models. + If it is specified, the new model will be trained based on the `base` model. + Otherwise, the new model will be trained from scratch. The `base` model + must be in the same Project and Location as the new Model to train, + and have the same model_type. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. The final + cost will be attempted to be close to the budget, though may end up + being (even) noticeably smaller - at the backend's discretion. This + especially may happen when further model training ceases to provide + any improvements. + If the budget is set to a value known to be insufficient to train a + Model for the given training set, the training won't be attempted and + will error. + The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. The name + can be up to 128 characters long and can be consist of any UTF-8 + characters. If a `base_model` was provided, the display_name in the + base_model will be overritten with this value. If not provided upon + creation, the job's display_name is used. + disable_early_stopping (bool): + Required. If true, the entire budget is used. This disables the early stopping + feature. By default, the early stopping feature is enabled, which means + that training might stop before the entire training budget has been + used, if further training does no longer brings significant improvement + to the model. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + # Retrieve the objective-specific training task schema based on prediction_type + training_task_definition = getattr( + schema.training_job.definition, f"automl_image_{self._prediction_type}" + ) + + training_task_inputs_dict = { + # required inputs + "modelType": self._model_type, + "budgetMilliNodeHours": budget_milli_node_hours, + # optional inputs + "disableEarlyStopping": disable_early_stopping, + } + + if self._prediction_type == "classification": + training_task_inputs_dict["multiLabel"] = self._multi_label + + # gca Model to be trained + model_tbt = gca_model.Model(encryption_spec=self._model_encryption_spec) + + model_tbt.display_name = model_display_name or self._display_name + + if base_model: + # Use provided base_model to pass to model_to_upload causing the + # description and labels from base_model to be passed onto the new model + model_tbt.description = getattr(base_model._gca_resource, "description") + model_tbt.labels = getattr(base_model._gca_resource, "labels") + + # Set ID of AI Platform Model to base this training job off of + training_task_inputs_dict["baseModelId"] = base_model.name + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + model=model_tbt, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"AutoML Image Training Pipeline {self.resource_name} is not " + "configured to upload a Model." + ) + + +class CustomPythonPackageTrainingJob(_CustomTrainingJob): + """Class to launch a Custom Training Job in AI Platform using a Python Package. + + Takes a training implementation as a python package and executes that package + in Cloud AI Platform Training. + """ + + def __init__( + self, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Constructs a Custom Training Job from a Python Package. + + job = aiplatform.CustomPythonPackageTrainingJob( + display_name='test-train', + python_package_gcs_uri='gs://my-bucket/my-python-package.tar.gz', + python_module_name='my-training-python-package.task', + container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest', + model_serving_container_image_uri='gcr.io/my-trainer/serving:1', + model_serving_container_predict_route='predict', + model_serving_container_health_route='metadata + ) + + Usage with Dataset: + + ds = aiplatform.TabularDataset( + 'projects/my-project/locations/us-central1/datasets/12345' + ) + + job.run( + ds, + replica_count=1, + model_display_name='my-trained-model' + ) + + Usage without Dataset: + + job.run( + replica_count=1, + model_display_name='my-trained-model' + ) + + To ensure your model gets saved in AI Platform, write your saved model to + os.environ["AIP_MODEL_DIR"] in your provided training script. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + python_package_gcs_uri (str): + Required: GCS location of the training python package. + python_module_name (str): + Required: The module name of the training python package. + container_uri (str): + Required: Uri of the training container image in the GCR. + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. + model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + container_uri=container_uri, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_description=model_description, + staging_bucket=staging_bucket, + ) + + self._package_gcs_uri = python_package_gcs_uri + self._python_module = python_module_name + + def run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Runs the custom training job. + + Distributed Training Support: + If replica count = 1 then one chief replica will be provisioned. If + replica_count > 1 the remainder will be provisioned as a worker replica pool. + ie: replica_count = 10 will result in 1 chief and 9 workers + All replicas have same machine_type, accelerator_type, and accelerator_count + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform.If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. Custom training script should + retrieve datasets through passed in environement variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + worker_pool_specs, managed_model = self._prepare_and_validate_run( + model_display_name=model_display_name, + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + return self._run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + worker_pool_specs=worker_pool_specs, + managed_model=managed_model, + args=args, + base_output_dir=base_output_dir, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + bigquery_destination=bigquery_destination, + sync=sync, + ) + + @base.optional_sync(construct_object_on_arg="managed_model") + def _run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], + annotation_schema_uri: Optional[str], + worker_pool_specs: _DistributedTrainingSpec, + managed_model: Optional[gca_model.Model] = None, + args: Optional[List[Union[str, float, int]]] = None, + base_output_dir: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + bigquery_destination: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Packages local script and launches training_job. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + managed_model (gca_model.Model): + Model proto if this script produces a Managed Model. + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + for spec in worker_pool_specs: + spec["pythonPackageSpec"] = { + "executorImageUri": self._container_uri, + "pythonModule": self._python_module, + "packageUris": [self._package_gcs_uri], + } + + if args: + spec["pythonPackageSpec"]["args"] = args + + ( + training_task_inputs, + base_output_dir, + ) = self._prepare_training_task_inputs_and_output_dir( + worker_pool_specs, base_output_dir + ) + + model = self._run_job( + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=training_task_inputs, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=managed_model, + gcs_destination_uri_prefix=base_output_dir, + bigquery_destination=bigquery_destination, + ) + + return model + + +class AutoMLVideoTrainingJob(_TrainingJob): + + _supported_training_schemas = ( + schema.training_job.definition.automl_video_classification, + schema.training_job.definition.automl_video_object_tracking, + schema.training_job.definition.automl_video_action_recognition, + ) + + def __init__( + self, + display_name: str, + prediction_type: str = "classification", + model_type: str = "CLOUD", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a AutoML Video Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + prediction_type (str): + The type of prediction the Model is to produce, one of: + "classification" - A video classification model classifies shots + and segments in your videos according to your own defined labels. + "object_tracking" - A video object tracking model detects and tracks + multiple objects in shots and segments. You can use these + models to track objects in your videos according to your + own pre-defined, custom labels. + "action_recognition" - A video action reconition model pinpoints + the location of actions with short temporal durations (~1 second). + model_type: str = "CLOUD" + Required. One of the following: + "CLOUD" - available for "classification", "object_tracking" and "action_recognition" + A Model best tailored to be used within Google Cloud, + and which cannot be exported. + "MOBILE_VERSATILE_1" - available for "classification", "object_tracking" and "action_recognition" + A model that, in addition to being available within Google + Cloud, can also be exported (see ModelService.ExportModel) + as a TensorFlow or TensorFlow Lite model and used on a + mobile or edge device with afterwards. + "MOBILE_CORAL_VERSATILE_1" - available only for "object_tracking" + A versatile model that is meant to be exported (see + ModelService.ExportModel) and used on a Google Coral device. + "MOBILE_CORAL_LOW_LATENCY_1" - available only for "object_tracking" + A model that trades off quality for low latency, to be + exported (see ModelService.ExportModel) and used on a + Google Coral device. + "MOBILE_JETSON_VERSATILE_1" - available only for "object_tracking" + A versatile model that is meant to be exported (see + ModelService.ExportModel) and used on an NVIDIA Jetson device. + "MOBILE_JETSON_LOW_LATENCY_1" - available only for "object_tracking" + A model that trades off quality for low latency, to be + exported (see ModelService.ExportModel) and used on an + NVIDIA Jetson device. + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Raises: + ValueError: When an invalid prediction_type and/or model_type is provided. + """ + valid_model_types = constants.AUTOML_VIDEO_PREDICTION_MODEL_TYPES.get( + prediction_type, None + ) + + if not valid_model_types: + raise ValueError( + f"'{prediction_type}' is not a supported prediction type for AutoML Video Training. " + f"Please choose one of: {tuple(constants.AUTOML_VIDEO_PREDICTION_MODEL_TYPES.keys())}." + ) + + if model_type not in valid_model_types: + raise ValueError( + f"'{model_type}' is not a supported model_type for prediction_type of '{prediction_type}'. " + f"Please choose one of: {tuple(valid_model_types)}" + ) + + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + self._model_type = model_type + self._prediction_type = prediction_type + + def run( + self, + dataset: datasets.VideoDataset, + training_fraction_split: float = 0.8, + test_fraction_split: float = 0.2, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the AutoML Image training job and returns a model. + + Data fraction splits: + ``training_fraction_split``, and ``test_fraction_split`` may optionally + be provided, they must sum to up to 1. If none of the fractions are set, + by default roughly 80% of data will be used for training, and 20% for test. + + Args: + dataset (datasets.VideoDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split: float = 0.8 + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + test_fraction_split: float = 0.2 + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. The name + can be up to 128 characters long and can be consist of any UTF-8 + characters. If not provided upon creation, the job's display_name is used. + sync: bool = True + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError: If Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Video Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Video Training has already run.") + + return self._run( + dataset=dataset, + training_fraction_split=training_fraction_split, + test_fraction_split=test_fraction_split, + model_display_name=model_display_name, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.VideoDataset, + training_fraction_split: float = 0.8, + test_fraction_split: float = 0.2, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, and ``test_fraction_split`` may optionally + be provided, they must sum to up to 1. If none of the fractions are set, + by default roughly 80% of data will be used for training, and 20% for test. + + Args: + dataset (datasets.VideoDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. The name + can be up to 128 characters long and can be consist of any UTF-8 + characters. If a `base_model` was provided, the display_name in the + base_model will be overritten with this value. If not provided upon + creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + # Retrieve the objective-specific training task schema based on prediction_type + training_task_definition = getattr( + schema.training_job.definition, f"automl_video_{self._prediction_type}" + ) + + training_task_inputs_dict = { + "modelType": self._model_type, + } + + # gca Model to be trained + model_tbt = gca_model.Model(encryption_spec=self._model_encryption_spec) + model_tbt.display_name = model_display_name or self._display_name + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=0.0, + test_fraction_split=test_fraction_split, + model=model_tbt, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"AutoML Video Training Pipeline {self.resource_name} is not " + "configured to upload a Model." + ) + + +class AutoMLTextTrainingJob(_TrainingJob): + _supported_training_schemas = ( + schema.training_job.definition.automl_text_classification, + schema.training_job.definition.automl_text_extraction, + schema.training_job.definition.automl_text_sentiment, + ) + + def __init__( + self, + display_name: str, + prediction_type: str, + multi_label: bool = False, + sentiment_max: int = 10, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a AutoML Text Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + prediction_type (str): + The type of prediction the Model is to produce, one of: + "classification" - A classification model analyzes text data and + returns a list of categories that apply to the text found in the data. + AI Platform offers both single-label and multi-label text classification models. + "extraction" - An entity extraction model inspects text data + for known entities referenced in the data and + labels those entities in the text. + "sentiment" - A sentiment analysis model inspects text data and identifies the + prevailing emotional opinion within it, especially to determine a writer's attitude + as positive, negative, or neutral. + multi_label (bool): + Required and only applicable for text classification task. If false, a single-label (multi-class) Model will be trained (i.e. + assuming that for each text snippet just up to one annotation may be + applicable). If true, a multi-label Model will be trained (i.e. + assuming that for each text snippet multiple annotations may be + applicable). + sentiment_max (int): + Required and only applicable for sentiment task. A sentiment is expressed as an integer + ordinal, where higher value means a more + positive sentiment. The range of sentiments that + will be used is between 0 and sentimentMax + (inclusive on both ends), and all the values in + the range must be represented in the dataset + before a model can be created. + Only the Annotations with this sentimentMax will + be used for training. sentimentMax value must be + between 1 and 10 (inclusive). + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + training_task_definition: str + training_task_inputs_dict: proto.Message + + if prediction_type == "classification": + training_task_definition = ( + schema.training_job.definition.automl_text_classification + ) + + training_task_inputs_dict = training_job_inputs.AutoMlTextClassificationInputs( + multi_label=multi_label + ) + elif prediction_type == "extraction": + training_task_definition = ( + schema.training_job.definition.automl_text_extraction + ) + + training_task_inputs_dict = training_job_inputs.AutoMlTextExtractionInputs() + elif prediction_type == "sentiment": + training_task_definition = ( + schema.training_job.definition.automl_text_sentiment + ) + + training_task_inputs_dict = training_job_inputs.AutoMlTextSentimentInputs( + sentiment_max=sentiment_max + ) + else: + raise ValueError( + "Prediction type must be one of 'classification', 'extraction', or 'sentiment'." + ) + + self._training_task_definition = training_task_definition + self._training_task_inputs_dict = training_task_inputs_dict + + def run( + self, + dataset: datasets.TextDataset, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TextDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + training_fraction_split: float = 0.8 + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split: float = 0.1 + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split: float = 0.1 + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. + The name can be up to 128 characters long and can consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource. + + Raises: + RuntimeError if Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Text Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Text Training has already run.") + + return self._run( + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + model_display_name=model_display_name, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.TextDataset, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TextDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For Text Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + if model_display_name is None: + model_display_name = self._display_name + + model = gca_model.Model( + display_name=model_display_name, + encryption_spec=self._model_encryption_spec, + ) + + return self._run_job( + training_task_definition=self._training_task_definition, + training_task_inputs=self._training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=None, + model=model, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"AutoML Text Training Pipeline {self.resource_name} is not " + "configured to upload a Model." + ) diff --git a/google/cloud/aiplatform/training_utils.py b/google/cloud/aiplatform/training_utils.py new file mode 100644 index 0000000000..a93ecaa1ce --- /dev/null +++ b/google/cloud/aiplatform/training_utils.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os + +from typing import Dict, Optional + + +class EnvironmentVariables: + """Passes on OS' environment variables""" + + @property + def training_data_uri(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for training data. None if + environment variable not set. + """ + return os.environ.get("AIP_TRAINING_DATA_URI") + + @property + def validation_data_uri(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for validation data. None + if environment variable not set. + """ + return os.environ.get("AIP_VALIDATION_DATA_URI") + + @property + def test_data_uri(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for test data. None if + environment variable not set. + """ + return os.environ.get("AIP_TEST_DATA_URI") + + @property + def model_dir(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for saving model artefacts. + None if environment variable not set. + """ + return os.environ.get("AIP_MODEL_DIR") + + @property + def checkpoint_dir(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for saving checkpoints. + None if environment variable not set. + """ + return os.environ.get("AIP_CHECKPOINT_DIR") + + @property + def tensorboard_log_dir(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for saving TensorBoard logs. + None if environment variable not set. + """ + return os.environ.get("AIP_TENSORBOARD_LOG_DIR") + + @property + def cluster_spec(self) -> Optional[Dict]: + """ + Returns: + json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#cluster-variables + None if environment variable not set. + """ + cluster_spec_env = os.environ.get("CLUSTER_SPEC") + if cluster_spec_env is not None: + return json.loads(cluster_spec_env) + else: + return None + + @property + def tf_config(self) -> Optional[Dict]: + """ + Returns: + json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#tf-config + None if environment variable not set. + """ + tf_config_env = os.environ.get("TF_CONFIG") + if tf_config_env is not None: + return json.loads(tf_config_env) + else: + return None diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py new file mode 100644 index 0000000000..7584c7d02e --- /dev/null +++ b/google/cloud/aiplatform/utils.py @@ -0,0 +1,469 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import abc +from collections import namedtuple +import logging +import re +from typing import Any, Match, Optional, Type, TypeVar, Tuple + +from google.api_core import client_options +from google.api_core import gapic_v1 +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform.compat.services import ( + dataset_service_client_v1beta1, + endpoint_service_client_v1beta1, + job_service_client_v1beta1, + model_service_client_v1beta1, + pipeline_service_client_v1beta1, + prediction_service_client_v1beta1, +) +from google.cloud.aiplatform.compat.services import ( + dataset_service_client_v1, + endpoint_service_client_v1, + job_service_client_v1, + model_service_client_v1, + pipeline_service_client_v1, + prediction_service_client_v1, +) + +from google.cloud.aiplatform.compat.types import ( + accelerator_type as gca_accelerator_type, +) + +AiPlatformServiceClient = TypeVar( + "AiPlatformServiceClient", + # v1beta1 + dataset_service_client_v1beta1.DatasetServiceClient, + endpoint_service_client_v1beta1.EndpointServiceClient, + model_service_client_v1beta1.ModelServiceClient, + prediction_service_client_v1beta1.PredictionServiceClient, + pipeline_service_client_v1beta1.PipelineServiceClient, + job_service_client_v1beta1.JobServiceClient, + # v1 + dataset_service_client_v1.DatasetServiceClient, + endpoint_service_client_v1.EndpointServiceClient, + model_service_client_v1.ModelServiceClient, + prediction_service_client_v1.PredictionServiceClient, + pipeline_service_client_v1.PipelineServiceClient, + job_service_client_v1.JobServiceClient, +) + +# TODO(b/170334193): Add support for resource names with non-integer IDs +# TODO(b/170334098): Add support for resource names more than one level deep +RESOURCE_NAME_PATTERN = re.compile( + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P\w+)\/(?P\d+)$" +) +RESOURCE_ID_PATTERN = re.compile(r"^\d+$") + +Fields = namedtuple("Fields", ["project", "location", "resource", "id"],) + + +def _match_to_fields(match: Match) -> Optional[Fields]: + """Normalize RegEx groups from resource name pattern Match to class Fields""" + if not match: + return None + + return Fields( + project=match["project"], + location=match["location"], + resource=match["resource"], + id=match["id"], + ) + + +def validate_id(resource_id: str) -> bool: + """Validate int64 resource ID number""" + return bool(RESOURCE_ID_PATTERN.match(resource_id)) + + +def extract_fields_from_resource_name( + resource_name: str, resource_noun: Optional[str] = None +) -> Optional[Fields]: + """Validates and returns extracted fields from a fully-qualified resource name. + Returns None if name is invalid. + + Args: + resource_name (str): + Required. A fully-qualified AI Platform (Unified) resource name + + resource_noun (str): + A plural resource noun to validate the resource name against. + For example, you would pass "datasets" to validate + "projects/123/locations/us-central1/datasets/456". + + Returns: + fields (Fields): + A named tuple containing four extracted fields from a resource name: + project, location, resource, and id. These fields can be used for + subsequent method calls in the SDK. + """ + fields = _match_to_fields(RESOURCE_NAME_PATTERN.match(resource_name)) + + if not fields: + return None + if resource_noun and fields.resource != resource_noun: + return None + + return fields + + +def full_resource_name( + resource_name: str, + resource_noun: str, + project: Optional[str] = None, + location: Optional[str] = None, +) -> str: + """ + Returns fully qualified resource name. + + Args: + resource_name (str): + Required. A fully-qualified AI Platform (Unified) resource name or + resource ID. + resource_noun (str): + A plural resource noun to validate the resource name against. + For example, you would pass "datasets" to validate + "projects/123/locations/us-central1/datasets/456". + project (str): + Optional project to retrieve resource_noun from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve resource_noun from. If not set, location + set in aiplatform.init will be used. + + Returns: + resource_name (str): + A fully-qualified AI Platform (Unified) resource name. + + Raises: + ValueError: + If resource name, resource ID or project ID not provided. + """ + validate_resource_noun(resource_noun) + # Fully qualified resource name, i.e. "projects/.../locations/.../datasets/12345" + valid_name = extract_fields_from_resource_name( + resource_name=resource_name, resource_noun=resource_noun + ) + + user_project = project or initializer.global_config.project + user_location = location or initializer.global_config.location + + # Partial resource name (i.e. "12345") with known project and location + if ( + not valid_name + and validate_project(user_project) + and validate_region(user_location) + and validate_id(resource_name) + ): + resource_name = f"projects/{user_project}/locations/{user_location}/{resource_noun}/{resource_name}" + # Invalid resource_name parameter + elif not valid_name: + raise ValueError(f"Please provide a valid {resource_noun[:-1]} name or ID") + + return resource_name + + +# TODO(b/172286889) validate resource noun +def validate_resource_noun(resource_noun: str) -> bool: + """Validates resource noun. + + Args: + resource_noun: resource noun to validate + Returns: + bool: True if no errors raised + Raises: + ValueError: If resource noun not supported. + """ + if resource_noun: + return True + raise ValueError("Please provide a valid resource noun") + + +# TODO(b/172288287) validate project +def validate_project(project: str) -> bool: + """Validates project. + + Args: + project: project to validate + Returns: + bool: True if no errors raised + Raises: + ValueError: If project does not exist. + """ + if project: + return True + raise ValueError("Please provide a valid project ID") + + +# TODO(b/172932277) verify display name only contains utf-8 chars +def validate_display_name(display_name: str): + """Verify display name is at most 128 chars + + Args: + display_name: display name to verify + Raises: + ValueError: display name is longer than 128 characters + """ + if len(display_name) > 128: + raise ValueError("Display name needs to be less than 128 characters.") + + +def validate_region(region: str) -> bool: + """Validates region against supported regions. + + Args: + region: region to validate + Returns: + bool: True if no errors raised + Raises: + ValueError: If region is not in supported regions. + """ + if not region: + raise ValueError( + f"Please provide a region, select from {constants.SUPPORTED_REGIONS}" + ) + + region = region.lower() + if region not in constants.SUPPORTED_REGIONS: + raise ValueError( + f"Unsupported region for AI Platform, select from {constants.SUPPORTED_REGIONS}" + ) + + return True + + +def validate_accelerator_type(accelerator_type: str) -> bool: + """Validates user provided accelerator_type string for training and prediction + + Args: + accelerator_type (str): + Represents a hardware accelerator type. + Returns: + bool: True if valid accelerator_type + Raises: + ValueError if accelerator type is invalid. + """ + if accelerator_type not in gca_accelerator_type.AcceleratorType._member_names_: + raise ValueError( + f"Given accelerator_type `{accelerator_type}` invalid. " + f"Choose one of {gca_accelerator_type.AcceleratorType._member_names_}" + ) + return True + + +def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optional[str]]: + """Given a complete GCS path, return the bucket name and prefix as a tuple. + + Example Usage: + + bucket, prefix = extract_bucket_and_prefix_from_gcs_path( + "gs://example-bucket/path/to/folder" + ) + + # bucket = "example-bucket" + # prefix = "path/to/folder" + + Args: + gcs_path (str): + Required. A full path to a Google Cloud Storage folder or resource. + Can optionally include "gs://" prefix or end in a trailing slash "/". + + Returns: + Tuple[str, Optional[str]] + A (bucket, prefix) pair from provided GCS path. If a prefix is not + present, a None will be returned in its place. + """ + if gcs_path.startswith("gs://"): + gcs_path = gcs_path[5:] + if gcs_path.endswith("/"): + gcs_path = gcs_path[:-1] + + gcs_parts = gcs_path.split("/", 1) + gcs_bucket = gcs_parts[0] + gcs_blob_prefix = None if len(gcs_parts) == 1 else gcs_parts[1] + + return (gcs_bucket, gcs_blob_prefix) + + +class ClientWithOverride: + class WrappedClient: + """Wrapper class for client that creates client at API invocation time.""" + + def __init__( + self, + client_class: Type[AiPlatformServiceClient], + client_options: client_options.ClientOptions, + client_info: gapic_v1.client_info.ClientInfo, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Stores parameters needed to instantiate client. + + client_class (AiPlatformServiceClient): + Required. Class of the client to use. + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. + """ + + self._client_class = client_class + self._credentials = credentials + self._client_options = client_options + self._client_info = client_info + + def __getattr__(self, name: str) -> Any: + """Instantiates client and returns attribute of the client.""" + temporary_client = self._client_class( + credentials=self._credentials, + client_options=self._client_options, + client_info=self._client_info, + ) + return getattr(temporary_client, name) + + @property + @abc.abstractmethod + def _is_temporary(self) -> bool: + pass + + @property + @classmethod + @abc.abstractmethod + def _default_version(self) -> str: + pass + + @property + @classmethod + @abc.abstractmethod + def _version_map(self) -> Tuple: + pass + + def __init__( + self, + client_options: client_options.ClientOptions, + client_info: gapic_v1.client_info.ClientInfo, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Stores parameters needed to instantiate client. + + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. + """ + + self._clients = { + version: self.WrappedClient( + client_class=client_class, + client_options=client_options, + client_info=client_info, + credentials=credentials, + ) + if self._is_temporary + else client_class( + client_options=client_options, + client_info=client_info, + credentials=credentials, + ) + for version, client_class in self._version_map + } + + def __getattr__(self, name: str) -> Any: + """Instantiates client and returns attribute of the client.""" + return getattr(self._clients[self._default_version], name) + + def select_version(self, version: str) -> AiPlatformServiceClient: + return self._clients[version] + + +class DatasetClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, dataset_service_client_v1.DatasetServiceClient), + (compat.V1BETA1, dataset_service_client_v1beta1.DatasetServiceClient), + ) + + +class EndpointClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, endpoint_service_client_v1.EndpointServiceClient), + (compat.V1BETA1, endpoint_service_client_v1beta1.EndpointServiceClient), + ) + + +class JobpointClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, job_service_client_v1.JobServiceClient), + (compat.V1BETA1, job_service_client_v1beta1.JobServiceClient), + ) + + +class ModelClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, model_service_client_v1.ModelServiceClient), + (compat.V1BETA1, model_service_client_v1beta1.ModelServiceClient), + ) + + +class PipelineClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, pipeline_service_client_v1.PipelineServiceClient), + (compat.V1BETA1, pipeline_service_client_v1beta1.PipelineServiceClient), + ) + + +class PredictionClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, prediction_service_client_v1.PredictionServiceClient), + (compat.V1BETA1, prediction_service_client_v1beta1.PredictionServiceClient), + ) + + +AiPlatformServiceClientWithOverride = TypeVar( + "AiPlatformServiceClientWithOverride", + DatasetClientWithOverride, + EndpointClientWithOverride, + JobpointClientWithOverride, + ModelClientWithOverride, + PipelineClientWithOverride, + PredictionClientWithOverride, +) + + +class LoggingWarningFilter(logging.Filter): + def filter(self, record): + return record.levelname == logging.WARNING diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index dfc64069bd..ff4505543b 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -237,7 +237,7 @@ async def create_dataset(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_dataset, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -322,7 +322,7 @@ async def get_dataset(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_dataset, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -415,7 +415,7 @@ async def update_dataset(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.update_dataset, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -495,7 +495,7 @@ async def list_datasets(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_datasets, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -595,7 +595,7 @@ async def delete_dataset(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_dataset, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -696,7 +696,7 @@ async def import_data(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.import_data, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -795,7 +795,7 @@ async def export_data(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.export_data, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -884,7 +884,7 @@ async def list_data_items(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_data_items, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -972,7 +972,7 @@ async def get_annotation_spec(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_annotation_spec, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1054,7 +1054,7 @@ async def list_annotations(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_annotations, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py index 15daeb6369..9f9b80b9a4 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py @@ -111,52 +111,52 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( self.create_dataset, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( self.get_dataset, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( self.update_dataset, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( self.list_datasets, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( self.delete_dataset, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( self.import_data, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( self.export_data, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( self.list_data_items, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( self.get_annotation_spec, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( self.list_annotations, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index ab1a3d3daf..e0e8a26c65 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -228,7 +228,7 @@ async def create_endpoint(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -314,7 +314,7 @@ async def get_endpoint(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -395,7 +395,7 @@ async def list_endpoints(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_endpoints, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -492,7 +492,7 @@ async def update_endpoint(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.update_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -583,7 +583,7 @@ async def delete_endpoint(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -712,7 +712,7 @@ async def deploy_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.deploy_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -832,7 +832,7 @@ async def undeploy_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.undeploy_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index 43520356ad..65e049d43f 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -110,37 +110,37 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( self.create_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( self.get_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( self.list_endpoints, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( self.update_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( self.delete_endpoint, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( self.deploy_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( self.undeploy_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index 5d9a5d68b5..8b2c365b88 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -256,7 +256,7 @@ async def create_custom_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -339,7 +339,7 @@ async def get_custom_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -420,7 +420,7 @@ async def list_custom_jobs(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_custom_jobs, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -520,7 +520,7 @@ async def delete_custom_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -611,7 +611,7 @@ async def cancel_custom_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.cancel_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -696,7 +696,7 @@ async def create_data_labeling_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -775,7 +775,7 @@ async def get_data_labeling_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -855,7 +855,7 @@ async def list_data_labeling_jobs(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_data_labeling_jobs, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -956,7 +956,7 @@ async def delete_data_labeling_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1037,7 +1037,7 @@ async def cancel_data_labeling_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.cancel_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1124,7 +1124,7 @@ async def create_hyperparameter_tuning_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1205,7 +1205,7 @@ async def get_hyperparameter_tuning_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1286,7 +1286,7 @@ async def list_hyperparameter_tuning_jobs(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_hyperparameter_tuning_jobs, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1387,7 +1387,7 @@ async def delete_hyperparameter_tuning_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1481,7 +1481,7 @@ async def cancel_hyperparameter_tuning_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.cancel_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1571,7 +1571,7 @@ async def create_batch_prediction_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1654,7 +1654,7 @@ async def get_batch_prediction_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1735,7 +1735,7 @@ async def list_batch_prediction_jobs(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_batch_prediction_jobs, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1837,7 +1837,7 @@ async def delete_batch_prediction_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1929,7 +1929,7 @@ async def cancel_batch_prediction_job(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.cancel_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py index f3ee6dc74a..0292f60059 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -117,102 +117,102 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( self.create_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( self.get_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( self.list_custom_jobs, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( self.delete_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( self.cancel_custom_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_data_labeling_job: gapic_v1.method.wrap_method( self.get_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_data_labeling_jobs: gapic_v1.method.wrap_method( self.list_data_labeling_jobs, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_data_labeling_job: gapic_v1.method.wrap_method( self.delete_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.cancel_data_labeling_job: gapic_v1.method.wrap_method( self.cancel_data_labeling_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.create_hyperparameter_tuning_job: gapic_v1.method.wrap_method( self.create_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_hyperparameter_tuning_job: gapic_v1.method.wrap_method( self.get_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_hyperparameter_tuning_jobs: gapic_v1.method.wrap_method( self.list_hyperparameter_tuning_jobs, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_hyperparameter_tuning_job: gapic_v1.method.wrap_method( self.delete_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.cancel_hyperparameter_tuning_job: gapic_v1.method.wrap_method( self.cancel_hyperparameter_tuning_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.create_batch_prediction_job: gapic_v1.method.wrap_method( self.create_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_batch_prediction_job: gapic_v1.method.wrap_method( self.get_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_batch_prediction_jobs: gapic_v1.method.wrap_method( self.list_batch_prediction_jobs, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_batch_prediction_job: gapic_v1.method.wrap_method( self.delete_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.cancel_batch_prediction_job: gapic_v1.method.wrap_method( self.cancel_batch_prediction_job, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 9e6cf8c669..9f505b26b2 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -173,14 +173,14 @@ def parse_annotated_dataset_path(path: str) -> Dict[str,str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -195,14 +195,14 @@ def parse_dataset_path(path: str) -> Dict[str,str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index f549b5e68d..5ac7d67aa2 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -241,7 +241,7 @@ async def upload_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.upload_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -324,7 +324,7 @@ async def get_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -405,7 +405,7 @@ async def list_models(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_models, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -500,7 +500,7 @@ async def update_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.update_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -593,7 +593,7 @@ async def delete_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -697,7 +697,7 @@ async def export_model(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.export_model, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -786,7 +786,7 @@ async def get_model_evaluation(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_model_evaluation, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -867,7 +867,7 @@ async def list_model_evaluations(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_model_evaluations, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -957,7 +957,7 @@ async def get_model_evaluation_slice(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_model_evaluation_slice, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -1039,7 +1039,7 @@ async def list_model_evaluation_slices(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_model_evaluation_slices, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py index 80c34f3e4a..262cb1c736 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -112,52 +112,52 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( self.upload_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( self.get_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( self.list_models, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( self.update_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( self.delete_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( self.export_model, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( self.get_model_evaluation, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_model_evaluations: gapic_v1.method.wrap_method( self.list_model_evaluations, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_model_evaluation_slice: gapic_v1.method.wrap_method( self.get_model_evaluation_slice, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_model_evaluation_slices: gapic_v1.method.wrap_method( self.list_model_evaluation_slices, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index 3b43bc080c..67d3fda8cb 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -238,7 +238,7 @@ async def create_training_pipeline(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -321,7 +321,7 @@ async def get_training_pipeline(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -402,7 +402,7 @@ async def list_training_pipelines(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_training_pipelines, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -503,7 +503,7 @@ async def delete_training_pipeline(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -596,7 +596,7 @@ async def cancel_training_pipeline(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.cancel_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py index 3a0cfa5a08..962fe14c76 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py @@ -111,27 +111,27 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.create_training_pipeline: gapic_v1.method.wrap_method( self.create_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_training_pipeline: gapic_v1.method.wrap_method( self.get_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_training_pipelines: gapic_v1.method.wrap_method( self.list_training_pipelines, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_training_pipeline: gapic_v1.method.wrap_method( self.delete_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.cancel_training_pipeline: gapic_v1.method.wrap_method( self.cancel_training_pipeline, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index 299694bdce..753202046e 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -241,7 +241,7 @@ async def predict(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.predict, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index 9e8a9841c0..ebba095d37 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -106,7 +106,7 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( self.predict, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index c05ca17005..22a7365041 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -238,7 +238,7 @@ async def create_specialist_pool(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -334,7 +334,7 @@ async def get_specialist_pool(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.get_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -415,7 +415,7 @@ async def list_specialist_pools(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.list_specialist_pools, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -516,7 +516,7 @@ async def delete_specialist_pool(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) @@ -619,7 +619,7 @@ async def update_specialist_pool(self, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( self._client._transport.update_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=DEFAULT_CLIENT_INFO, ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py index 878e095edb..e05bc7d77c 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -109,27 +109,27 @@ def _prep_wrapped_messages(self, client_info): self._wrapped_methods = { self.create_specialist_pool: gapic_v1.method.wrap_method( self.create_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( self.get_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.delete_specialist_pool: gapic_v1.method.wrap_method( self.delete_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), self.update_specialist_pool: gapic_v1.method.wrap_method( self.update_specialist_pool, - default_timeout=None, + default_timeout=5.0, client_info=client_info, ), diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index e671bbfa1c..99fdb689e3 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -936,7 +936,7 @@ async def update_entity_type(self, Args: request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateEntityTypeRequest`): The request object. Request message for - [FeaturestoreService.UpdateEntityTypes][]. + ``FeaturestoreService.UpdateEntityType``. entity_type (:class:`google.cloud.aiplatform_v1beta1.types.EntityType`): Required. The EntityType's ``name`` field is used to identify the EntityType to be updated. Format: @@ -959,6 +959,8 @@ async def update_entity_type(self, - ``description`` - ``labels`` + - ``monitoring_config.snapshot_analysis.disabled`` + - ``monitoring_config.snapshot_analysis.monitoring_interval`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index c566a9b24e..1bef3bb531 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -1134,7 +1134,7 @@ def update_entity_type(self, Args: request (google.cloud.aiplatform_v1beta1.types.UpdateEntityTypeRequest): The request object. Request message for - [FeaturestoreService.UpdateEntityTypes][]. + ``FeaturestoreService.UpdateEntityType``. entity_type (google.cloud.aiplatform_v1beta1.types.EntityType): Required. The EntityType's ``name`` field is used to identify the EntityType to be updated. Format: @@ -1157,6 +1157,8 @@ def update_entity_type(self, - ``description`` - ``labels`` + - ``monitoring_config.snapshot_analysis.disabled`` + - ``monitoring_config.snapshot_analysis.monitoring_interval`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 1c08ffef30..3d57aa5c1f 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -173,25 +173,25 @@ def parse_annotated_dataset_path(path: str) -> Dict[str,str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str,str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py index c5d1b4034f..1844d2ac15 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py @@ -485,7 +485,9 @@ class FeatureValueDestination(proto.Message): Attributes: bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): - Output in BigQuery format. output_uri in + Output in BigQuery format. + ``BigQueryDestination.output_uri`` + in ``FeatureValueDestination.bigquery_destination`` must refer to a table. tfrecord_destination (google.cloud.aiplatform_v1beta1.types.TFRecordDestination): @@ -669,7 +671,8 @@ def raw_page(self): class UpdateEntityTypeRequest(proto.Message): - r"""Request message for [FeaturestoreService.UpdateEntityTypes][]. + r"""Request message for + ``FeaturestoreService.UpdateEntityType``. Attributes: entity_type (google.cloud.aiplatform_v1beta1.types.EntityType): @@ -690,6 +693,8 @@ class UpdateEntityTypeRequest(proto.Message): - ``description`` - ``labels`` + - ``monitoring_config.snapshot_analysis.disabled`` + - ``monitoring_config.snapshot_analysis.monitoring_interval`` """ entity_type = proto.Field(proto.MESSAGE, number=1, @@ -738,7 +743,7 @@ class CreateFeatureRequest(proto.Message): This value may be up to 60 characters, and valid characters are ``[a-z0-9_]``. The first character cannot be a number. - The value must be unique within an entitytype. + The value must be unique within an EntityType. """ parent = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index 3a3abf1b06..72e3e24e7a 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -130,7 +130,7 @@ class CsvDestination(proto.Message): Attributes: gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): - Google Cloud Storage location. + Required. Google Cloud Storage location. """ gcs_destination = proto.Field(proto.MESSAGE, number=1, @@ -143,7 +143,7 @@ class TFRecordDestination(proto.Message): Attributes: gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): - Google Cloud Storage location. + Required. Google Cloud Storage location. """ gcs_destination = proto.Field(proto.MESSAGE, number=1, diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py new file mode 100644 index 0000000000..580c6a962d --- /dev/null +++ b/samples/model-builder/conftest.py @@ -0,0 +1,205 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +from google.cloud import aiplatform +import pytest + + +@pytest.fixture +def mock_sdk_init(): + with patch.object(aiplatform, "init") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Dataset Fixtures +---------------------------------------------------------------------------- +""" + +"""Dataset objects returned by SomeDataset(), create(), import_data(), etc. """ + + +@pytest.fixture +def mock_image_dataset(): + mock = MagicMock(aiplatform.datasets.ImageDataset) + yield mock + + +@pytest.fixture +def mock_tabular_dataset(): + mock = MagicMock(aiplatform.datasets.TabularDataset) + yield mock + + +@pytest.fixture +def mock_text_dataset(): + mock = MagicMock(aiplatform.datasets.TextDataset) + yield mock + + +@pytest.fixture +def mock_video_dataset(): + mock = MagicMock(aiplatform.datasets.VideoDataset) + yield mock + + +"""Mocks for getting an existing Dataset, i.e. ds = aiplatform.ImageDataset(...) """ + + +@pytest.fixture +def mock_get_image_dataset(mock_image_dataset): + with patch.object(aiplatform, "ImageDataset") as mock_get_image_dataset: + mock_get_image_dataset.return_value = mock_image_dataset + yield mock_get_image_dataset + + +@pytest.fixture +def mock_get_tabular_dataset(mock_tabular_dataset): + with patch.object(aiplatform, "TabularDataset") as mock_get_tabular_dataset: + mock_get_tabular_dataset.return_value = mock_tabular_dataset + yield mock_get_tabular_dataset + + +@pytest.fixture +def mock_get_text_dataset(mock_text_dataset): + with patch.object(aiplatform, "TextDataset") as mock_get_text_dataset: + mock_get_text_dataset.return_value = mock_text_dataset + yield mock_get_text_dataset + + +@pytest.fixture +def mock_get_video_dataset(mock_video_dataset): + with patch.object(aiplatform, "VideoDataset") as mock_get_video_dataset: + mock_get_video_dataset.return_value = mock_video_dataset + yield mock_get_video_dataset + + +"""Mocks for creating a new Dataset, i.e. aiplatform.ImageDataset.create(...) """ + + +@pytest.fixture +def mock_create_image_dataset(mock_image_dataset): + with patch.object(aiplatform.ImageDataset, "create") as mock_create_image_dataset: + mock_create_image_dataset.return_value = mock_image_dataset + yield mock_create_image_dataset + + +@pytest.fixture +def mock_create_tabular_dataset(mock_tabular_dataset): + with patch.object( + aiplatform.TabularDataset, "create" + ) as mock_create_tabular_dataset: + mock_create_tabular_dataset.return_value = mock_tabular_dataset + yield mock_create_tabular_dataset + + +@pytest.fixture +def mock_create_text_dataset(mock_text_dataset): + with patch.object(aiplatform.TextDataset, "create") as mock_create_text_dataset: + mock_create_text_dataset.return_value = mock_text_dataset + yield mock_create_text_dataset + + +@pytest.fixture +def mock_create_video_dataset(mock_video_dataset): + with patch.object(aiplatform.VideoDataset, "create") as mock_create_video_dataset: + mock_create_video_dataset.return_value = mock_video_dataset + yield mock_create_video_dataset + + +"""Mocks for SomeDataset.import_data() """ + + +@pytest.fixture +def mock_import_text_dataset(mock_text_dataset): + with patch.object(mock_text_dataset, "import_data") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +TrainingJob Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_init_automl_image_training_job(): + with patch.object( + aiplatform.training_jobs.AutoMLImageTrainingJob, "__init__" + ) as mock: + mock.return_value = None + yield mock + + +@pytest.fixture +def mock_run_automl_image_training_job(): + with patch.object(aiplatform.training_jobs.AutoMLImageTrainingJob, "run") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Model Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_init_model(): + with patch.object(aiplatform.models.Model, "__init__") as mock: + mock.return_value = None + yield mock + + +@pytest.fixture +def mock_batch_predict_model(): + with patch.object(aiplatform.models.Model, "batch_predict") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Job Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_create_batch_prediction_job(): + with patch.object(aiplatform.jobs.BatchPredictionJob, "create") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Endpoint Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_endpoint(): + mock = MagicMock(aiplatform.models.Endpoint) + yield mock + + +@pytest.fixture +def mock_get_endpoint(mock_endpoint): + with patch.object(aiplatform, "Endpoint") as mock_get_endpoint: + mock_get_endpoint.return_value = mock_endpoint + yield mock_get_endpoint diff --git a/samples/model-builder/create_and_import_dataset_image_sample.py b/samples/model-builder/create_and_import_dataset_image_sample.py new file mode 100644 index 0000000000..bab7c8a59c --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_image_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_image_sample] +def create_and_import_dataset_image_sample( + project: str, + location: str, + display_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_create_and_import_dataset_image_sample] diff --git a/samples/model-builder/create_and_import_dataset_image_sample_test.py b/samples/model-builder/create_and_import_dataset_image_sample_test.py new file mode 100644 index 0000000000..6991ff3a13 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_image_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import create_and_import_dataset_image_sample +import test_constants as constants + + +def test_create_and_import_dataset_image_sample( + mock_sdk_init, mock_create_image_dataset +): + + create_and_import_dataset_image_sample.create_and_import_dataset_image_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.image.single_label_classification, + sync=True, + ) diff --git a/samples/model-builder/create_and_import_dataset_text_sample.py b/samples/model-builder/create_and_import_dataset_text_sample.py new file mode 100644 index 0000000000..e3321020bf --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_text_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_text_sample] +def create_and_import_dataset_text_sample( + project: str, + location: str, + display_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.single_label_classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_create_and_import_dataset_text_sample] diff --git a/samples/model-builder/create_and_import_dataset_text_sample_test.py b/samples/model-builder/create_and_import_dataset_text_sample_test.py new file mode 100644 index 0000000000..e41082d06f --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_text_sample_test.py @@ -0,0 +1,39 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import create_and_import_dataset_text_sample +import test_constants as constants + + +def test_create_and_import_dataset_text_sample(mock_sdk_init, mock_create_text_dataset): + + create_and_import_dataset_text_sample.create_and_import_dataset_text_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_text_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.single_label_classification, + sync=True, + ) diff --git a/samples/model-builder/create_batch_prediction_job_sample.py b/samples/model-builder/create_batch_prediction_job_sample.py new file mode 100644 index 0000000000..9bd5c697a5 --- /dev/null +++ b/samples/model-builder/create_batch_prediction_job_sample.py @@ -0,0 +1,49 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_batch_prediction_job_sample] +def create_batch_prediction_job_sample( + project: str, + location: str, + model_resource_name: str, + job_display_name: str, + gcs_source: Union[str, Sequence[str]], + gcs_destination: str, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + my_model = aiplatform.Model(model_resource_name) + + batch_prediction_job = my_model.batch_predict( + job_display_name=job_display_name, + gcs_source=gcs_source, + gcs_destination_prefix=gcs_destination, + sync=sync, + ) + + batch_prediction_job.wait() + + print(batch_prediction_job.display_name) + print(batch_prediction_job.resource_name) + print(batch_prediction_job.state) + return batch_prediction_job + + +# [END aiplatform_sdk_create_batch_prediction_job_sample] diff --git a/samples/model-builder/create_batch_prediction_job_sample_test.py b/samples/model-builder/create_batch_prediction_job_sample_test.py new file mode 100644 index 0000000000..f39c1020b5 --- /dev/null +++ b/samples/model-builder/create_batch_prediction_job_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_batch_prediction_job_sample +import test_constants as constants + + +def test_create_batch_prediction_job_sample( + mock_sdk_init, mock_init_model, mock_batch_predict_model +): + + create_batch_prediction_job_sample.create_batch_prediction_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_resource_name=constants.MODEL_NAME, + job_display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + gcs_destination=constants.GCS_DESTINATION, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_model.assert_called_once_with(constants.MODEL_NAME) + mock_batch_predict_model.assert_called_once_with( + job_display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + gcs_destination_prefix=constants.GCS_DESTINATION, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample.py b/samples/model-builder/create_training_pipeline_image_classification_sample.py new file mode 100644 index 0000000000..050d40af82 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_image_classification_sample.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_image_classification_sample] +def create_training_pipeline_image_classification_sample( + project: str, + display_name: str, + dataset_id: int, + location: str = "us-central1", + model_display_name: str = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLImageTrainingJob(display_name=display_name) + + my_image_ds = aiplatform.ImageDataset(dataset_id) + + model = job.run( + dataset=my_image_ds, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_image_classification_sample] diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py new file mode 100644 index 0000000000..c49e0e5f05 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_image_classification_sample +import test_constants as constants + + +def test_create_training_pipeline_image_classification_sample( + mock_sdk_init, + mock_image_dataset, + mock_init_automl_image_training_job, + mock_run_automl_image_training_job, + mock_get_image_dataset, +): + + create_training_pipeline_image_classification_sample.create_training_pipeline_image_classification_sample( + project=constants.PROJECT, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_image_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_automl_image_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) + mock_run_automl_image_training_job.assert_called_once_with( + dataset=mock_image_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/import_data_text_classification_single_label_sample.py b/samples/model-builder/import_data_text_classification_single_label_sample.py new file mode 100644 index 0000000000..c63cc3f1d1 --- /dev/null +++ b/samples/model-builder/import_data_text_classification_single_label_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_text_classification_single_label_sample] +def import_data_text_classification_single_label( + project: str, + location: str, + dataset: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset(dataset) + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.single_label_classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_text_classification_single_label_sample] diff --git a/samples/model-builder/import_data_text_classification_single_label_sample_test.py b/samples/model-builder/import_data_text_classification_single_label_sample_test.py new file mode 100644 index 0000000000..1765ab013e --- /dev/null +++ b/samples/model-builder/import_data_text_classification_single_label_sample_test.py @@ -0,0 +1,43 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import import_data_text_classification_single_label_sample +import test_constants as constants + + +def test_import_data_text_classification_single_label_sample( + mock_sdk_init, mock_get_text_dataset, mock_import_text_dataset +): + + import_data_text_classification_single_label_sample.import_data_text_classification_single_label( + project=constants.PROJECT, + location=constants.LOCATION, + dataset=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_text_dataset.assert_called_once_with(constants.DATASET_NAME) + + mock_import_text_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.single_label_classification, + sync=True, + ) diff --git a/samples/model-builder/import_data_text_entity_extraction_sample.py b/samples/model-builder/import_data_text_entity_extraction_sample.py new file mode 100644 index 0000000000..7e00d57632 --- /dev/null +++ b/samples/model-builder/import_data_text_entity_extraction_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_text_entity_extraction_sample] +def import_data_text_entity_extraction_sample( + project: str, + location: str, + dataset: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset(dataset) + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.extraction, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_text_entity_extraction_sample] diff --git a/samples/model-builder/import_data_text_entity_extraction_sample_test.py b/samples/model-builder/import_data_text_entity_extraction_sample_test.py new file mode 100644 index 0000000000..a3b93e9200 --- /dev/null +++ b/samples/model-builder/import_data_text_entity_extraction_sample_test.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import import_data_text_entity_extraction_sample +import test_constants as constants + + +def test_import_data_text_entity_extraction_sample( + mock_sdk_init, mock_get_text_dataset, mock_import_text_dataset +): + + import_data_text_entity_extraction_sample.import_data_text_entity_extraction_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_text_dataset.assert_called_once_with( + constants.DATASET_NAME, + ) + + mock_import_text_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.extraction, + sync=True, + ) diff --git a/samples/model-builder/import_data_text_sentiment_analysis_sample.py b/samples/model-builder/import_data_text_sentiment_analysis_sample.py new file mode 100644 index 0000000000..3861a1102a --- /dev/null +++ b/samples/model-builder/import_data_text_sentiment_analysis_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_text_sentiment_analysis_sample] +def import_data_text_sentiment_analysis_sample( + project: str, + location: str, + dataset: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset(dataset) + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.sentiment, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_text_sentiment_analysis_sample] diff --git a/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py b/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py new file mode 100644 index 0000000000..2134d66b35 --- /dev/null +++ b/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import import_data_text_sentiment_analysis_sample +import test_constants as constants + + +def test_import_data_text_sentiment_analysis_sample( + mock_sdk_init, mock_get_text_dataset, mock_import_text_dataset +): + + import_data_text_sentiment_analysis_sample.import_data_text_sentiment_analysis_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_text_dataset.assert_called_once_with( + constants.DATASET_NAME, + ) + + mock_import_text_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.sentiment, + sync=True, + ) diff --git a/samples/model-builder/init_sample.py b/samples/model-builder/init_sample.py new file mode 100644 index 0000000000..8ced169ec4 --- /dev/null +++ b/samples/model-builder/init_sample.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.auth import credentials as auth_credentials +from google.cloud import aiplatform + + +# [START aiplatform_sdk_init_sample] +def init_sample( + project: Optional[str] = None, + location: Optional[str] = None, + experiment: Optional[str] = None, + staging_bucket: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, +): + aiplatform.init( + project=project, + location=location, + experiment=experiment, + staging_bucket=staging_bucket, + credentials=credentials, + encryption_spec_key_name=encryption_spec_key_name, + ) + + +# [END aiplatform_sdk_init_sample] diff --git a/samples/model-builder/init_sample_test.py b/samples/model-builder/init_sample_test.py new file mode 100644 index 0000000000..3c4684a255 --- /dev/null +++ b/samples/model-builder/init_sample_test.py @@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import init_sample +import test_constants as constants + + +def test_init_sample(mock_sdk_init): + + init_sample.init_sample( + project=constants.PROJECT, + location=constants.LOCATION_EUROPE, + experiment=constants.EXPERIMENT_NAME, + staging_bucket=constants.STAGING_BUCKET, + credentials=constants.CREDENTIALS, + encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION_EUROPE, + experiment=constants.EXPERIMENT_NAME, + staging_bucket=constants.STAGING_BUCKET, + credentials=constants.CREDENTIALS, + encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME, + ) diff --git a/samples/model-builder/noxfile.py b/samples/model-builder/noxfile.py new file mode 100644 index 0000000000..83bf446de2 --- /dev/null +++ b/samples/model-builder/noxfile.py @@ -0,0 +1,221 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +from pathlib import Path +import sys + +import nox + +# WARNING - WARNING - WARNING - WARNING - WARNING +# WARNING - WARNING - WARNING - WARNING - WARNING +# DO NOT EDIT THIS FILE EVER! +# WARNING - WARNING - WARNING - WARNING - WARNING +# WARNING - WARNING - WARNING - WARNING - WARNING + +# Copy `noxfile_config.py` to your directory and modify it instead. + + +# `TEST_CONFIG` dict is a configuration hook that allows users to +# modify the test configurations. The values here should be in sync +# with `noxfile_config.py`. Users will copy `noxfile_config.py` into +# their directory and modify it. + +TEST_CONFIG = { + # You can opt out from the test for specific Python versions. + "ignored_versions": ["2.7"], + # An envvar key for determining the project id to use. Change it + # to 'BUILD_SPECIFIC_GCLOUD_PROJECT' if you want to opt in using a + # build specific Cloud project. You can also use your own string + # to use your own Cloud project. + "gcloud_project_env": "GOOGLE_CLOUD_PROJECT", + # 'gcloud_project_env': 'BUILD_SPECIFIC_GCLOUD_PROJECT', + # A dictionary you want to inject into your test. Don't put any + # secrets here. These values will override predefined values. + "envs": {}, +} + + +try: + # Ensure we can import noxfile_config in the project's directory. + sys.path.append(".") + from noxfile_config import TEST_CONFIG_OVERRIDE +except ImportError as e: + print("No user noxfile_config found: detail: {}".format(e)) + TEST_CONFIG_OVERRIDE = {} + +# Update the TEST_CONFIG with the user supplied values. +TEST_CONFIG.update(TEST_CONFIG_OVERRIDE) + + +def get_pytest_env_vars(): + """Returns a dict for pytest invocation.""" + ret = {} + + # Override the GCLOUD_PROJECT and the alias. + env_key = TEST_CONFIG["gcloud_project_env"] + # This should error out if not set. + ret["GOOGLE_CLOUD_PROJECT"] = os.environ[env_key] + + # Apply user supplied envs. + ret.update(TEST_CONFIG["envs"]) + return ret + + +# DO NOT EDIT - automatically generated. +# All versions used to tested samples. +ALL_VERSIONS = ["2.7", "3.6", "3.7", "3.8"] + +# Any default versions that should be ignored. +IGNORED_VERSIONS = TEST_CONFIG["ignored_versions"] + +TESTED_VERSIONS = sorted([v for v in ALL_VERSIONS if v not in IGNORED_VERSIONS]) + +INSTALL_LIBRARY_FROM_SOURCE = bool(os.environ.get("INSTALL_LIBRARY_FROM_SOURCE", False)) +# +# Style Checks +# + + +def _determine_local_import_names(start_dir): + """Determines all import names that should be considered "local". + + This is used when running the linter to insure that import order is + properly checked. + """ + file_ext_pairs = [os.path.splitext(path) for path in os.listdir(start_dir)] + return [ + basename + for basename, extension in file_ext_pairs + if extension == ".py" + or os.path.isdir(os.path.join(start_dir, basename)) + and basename not in ("__pycache__") + ] + + +# Linting with flake8. +# +# We ignore the following rules: +# E203: whitespace before ‘:’ +# E266: too many leading ‘#’ for block comment +# E501: line too long +# I202: Additional newline in a section of imports +# +# We also need to specify the rules which are ignored by default: +# ['E226', 'W504', 'E126', 'E123', 'W503', 'E24', 'E704', 'E121'] +FLAKE8_COMMON_ARGS = [ + "--show-source", + "--builtin=gettext", + "--max-complexity=20", + "--import-order-style=google", + "--exclude=.nox,.cache,env,lib,generated_pb2,*_pb2.py,*_pb2_grpc.py", + "--ignore=E121,E123,E126,E203,E226,E24,E266,E501,E704,W503,W504,I202", + "--max-line-length=88", +] + + +@nox.session +def lint(session): + session.install("flake8", "flake8-import-order") + + local_names = _determine_local_import_names(".") + args = FLAKE8_COMMON_ARGS + [ + "--application-import-names", + ",".join(local_names), + ".", + ] + session.run("flake8", *args) + + +# +# Sample Tests +# + + +PYTEST_COMMON_ARGS = ["--junitxml=sponge_log.xml"] + + +def _session_tests(session, post_install=None): + """Runs py.test for a particular project.""" + if os.path.exists("requirements.txt"): + session.install("-r", "requirements.txt") + + if os.path.exists("requirements-test.txt"): + session.install("-r", "requirements-test.txt") + + if INSTALL_LIBRARY_FROM_SOURCE: + session.install("-e", _get_repo_root()) + + if post_install: + post_install(session) + + session.run( + "pytest", + *(PYTEST_COMMON_ARGS + session.posargs), + # Pytest will return 5 when no tests are collected. This can happen + # on travis where slow and flaky tests are excluded. + # See http://doc.pytest.org/en/latest/_modules/_pytest/main.html + success_codes=[0, 5], + env=get_pytest_env_vars() + ) + + +@nox.session(python=ALL_VERSIONS) +def py(session): + """Runs py.test for a sample using the specified version of Python.""" + if session.python in TESTED_VERSIONS: + _session_tests(session) + else: + session.skip( + "SKIPPED: {} tests are disabled for this sample.".format(session.python) + ) + + +# +# Readmegen +# + + +def _get_repo_root(): + """ Returns the root folder of the project. """ + # Get root of this repository. Assume we don't have directories nested deeper than 10 items. + p = Path(os.getcwd()) + for i in range(10): + if p is None: + break + if Path(p / ".git").exists(): + return str(p) + p = p.parent + raise Exception("Unable to detect repository root.") + + +GENERATED_READMES = sorted([x for x in Path(".").rglob("*.rst.in")]) + + +@nox.session +@nox.parametrize("path", GENERATED_READMES) +def readmegen(session, path): + """(Re-)generates the readme for a sample.""" + session.install("jinja2", "pyyaml") + dir_ = os.path.dirname(path) + + if os.path.exists(os.path.join(dir_, "requirements.txt")): + session.install("-r", os.path.join(dir_, "requirements.txt")) + + in_file = os.path.join(dir_, "README.rst.in") + session.run( + "python", _get_repo_root() + "/scripts/readme-gen/readme_gen.py", in_file + ) diff --git a/samples/model-builder/predict_text_classification_single_label_sample.py b/samples/model-builder/predict_text_classification_single_label_sample.py new file mode 100644 index 0000000000..195b519750 --- /dev/null +++ b/samples/model-builder/predict_text_classification_single_label_sample.py @@ -0,0 +1,33 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_predict_text_classification_single_label_sample] +def predict_text_classification_single_label_sample( + project, location, endpoint, content +): + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint) + + response = endpoint.predict(instances=[{"content": content}], parameters={}) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_predict_text_classification_single_label_sample] diff --git a/samples/model-builder/predict_text_classification_single_label_sample_test.py b/samples/model-builder/predict_text_classification_single_label_sample_test.py new file mode 100644 index 0000000000..c446235a79 --- /dev/null +++ b/samples/model-builder/predict_text_classification_single_label_sample_test.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import predict_text_classification_single_label_sample +import test_constants as constants + + +def test_predict_text_classification_single_label_sample( + mock_sdk_init, mock_get_endpoint +): + + predict_text_classification_single_label_sample.predict_text_classification_single_label_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint=constants.ENDPOINT_NAME, + content=constants.PREDICTION_TEXT_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with( + constants.ENDPOINT_NAME, + ) diff --git a/samples/model-builder/predict_text_entity_extraction_sample.py b/samples/model-builder/predict_text_entity_extraction_sample.py new file mode 100644 index 0000000000..577296333a --- /dev/null +++ b/samples/model-builder/predict_text_entity_extraction_sample.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_predict_text_entity_extraction_sample] +def predict_text_entity_extraction_sample(project, location, endpoint_id, content): + + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint_id) + + response = endpoint.predict(instances=[{"content": content}], parameters={}) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_predict_text_entity_extraction_sample] diff --git a/samples/model-builder/predict_text_entity_extraction_sample_test.py b/samples/model-builder/predict_text_entity_extraction_sample_test.py new file mode 100644 index 0000000000..3ca2b49b43 --- /dev/null +++ b/samples/model-builder/predict_text_entity_extraction_sample_test.py @@ -0,0 +1,35 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import predict_text_entity_extraction_sample +import test_constants as constants + + +def test_predict_text_entity_extraction_sample(mock_sdk_init, mock_get_endpoint): + + predict_text_entity_extraction_sample.predict_text_entity_extraction_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint_id=constants.ENDPOINT_NAME, + content=constants.PREDICTION_TEXT_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with( + constants.ENDPOINT_NAME, + ) diff --git a/samples/model-builder/predict_text_sentiment_analysis_sample.py b/samples/model-builder/predict_text_sentiment_analysis_sample.py new file mode 100644 index 0000000000..9fca0b4168 --- /dev/null +++ b/samples/model-builder/predict_text_sentiment_analysis_sample.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_predict_text_sentiment_analysis_sample] +def predict_text_sentiment_analysis_sample(project, location, endpoint_id, content): + + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint_id) + + response = endpoint.predict(instances=[{"content": content}], parameters={}) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_predict_text_sentiment_analysis_sample] diff --git a/samples/model-builder/predict_text_sentiment_analysis_sample_test.py b/samples/model-builder/predict_text_sentiment_analysis_sample_test.py new file mode 100644 index 0000000000..c2ed180c9f --- /dev/null +++ b/samples/model-builder/predict_text_sentiment_analysis_sample_test.py @@ -0,0 +1,35 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import predict_text_sentiment_analysis_sample +import test_constants as constants + + +def test_predict_text_sentiment_analysis_sample(mock_sdk_init, mock_get_endpoint): + + predict_text_sentiment_analysis_sample.predict_text_sentiment_analysis_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint_id=constants.ENDPOINT_NAME, + content=constants.PREDICTION_TEXT_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with( + constants.ENDPOINT_NAME, + ) diff --git a/samples/model-builder/requirements-tests.txt b/samples/model-builder/requirements-tests.txt new file mode 100644 index 0000000000..f53c4c11a6 --- /dev/null +++ b/samples/model-builder/requirements-tests.txt @@ -0,0 +1 @@ +pytest >= 6.2 diff --git a/samples/model-builder/requirements.txt b/samples/model-builder/requirements.txt new file mode 100644 index 0000000000..efe811b2c3 --- /dev/null +++ b/samples/model-builder/requirements.txt @@ -0,0 +1,2 @@ +pytest >= 6.2 +git+https://github.com/googleapis/python-aiplatform.git@mb-release#egg=google-cloud-aiplatform \ No newline at end of file diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py new file mode 100644 index 0000000000..50dfa968b4 --- /dev/null +++ b/samples/model-builder/test_constants.py @@ -0,0 +1,53 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from random import randint +from uuid import uuid4 + +from google.auth import credentials + +PROJECT = "abc" +LOCATION = "us-central1" +LOCATION_EUROPE = "europe-west4" +LOCATION_ASIA = "asia-east1" +PARENT = f"projects/{PROJECT}/locations/{LOCATION}" + +DISPLAY_NAME = str(uuid4()) # Create random display name +DISPLAY_NAME_2 = str(uuid4()) + +STAGING_BUCKET = "gs://my-staging-bucket" +EXPERIMENT_NAME = "fraud-detection-trial-72" +CREDENTIALS = credentials.AnonymousCredentials() + +RESOURCE_ID = str(randint(10000000, 99999999)) # Create random resource ID +RESOURCE_ID_2 = str(randint(10000000, 99999999)) + +BATCH_PREDICTION_JOB_NAME = f"{PARENT}/batchPredictionJobs/{RESOURCE_ID}" +DATASET_NAME = f"{PARENT}/datasets/{RESOURCE_ID}" +ENDPOINT_NAME = f"{PARENT}/endpoints/{RESOURCE_ID}" +MODEL_NAME = f"{PARENT}/models/{RESOURCE_ID}" +TRAINING_JOB_NAME = f"{PARENT}/trainingJobs/{RESOURCE_ID}" + +GCS_SOURCES = ["gs://bucket1/source1.jsonl", "gs://bucket7/source4.jsonl"] +GCS_DESTINATION = "gs://bucket3/output-dir/" + +TRAINING_FRACTION_SPLIT = 0.7 +TEST_FRACTION_SPLIT = 0.15 +VALIDATION_FRACTION_SPLIT = 0.15 + +BUDGET_MILLI_NODE_HOURS_8000 = 8000 + +ENCRYPTION_SPEC_KEY_NAME = f"{PARENT}/keyRings/{RESOURCE_ID}/cryptoKeys/{RESOURCE_ID_2}" + +PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output" diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index 481213275f..8a6680087b 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,3 @@ pytest==6.2.2 google-cloud-storage>=1.26.0, <2.0.0dev -google-cloud-aiplatform==0.6.0 +google-cloud-aiplatform==0.7.1 diff --git a/setup.py b/setup.py index cc19d7a867..b89d2a6417 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ import setuptools # type: ignore name = "google-cloud-aiplatform" -version = "0.6.0" +version = "0.7.1" description = "Cloud AI Platform API client library" package_root = os.path.abspath(os.path.dirname(__file__)) @@ -46,6 +46,7 @@ "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", "proto-plus >= 1.10.1", "google-cloud-storage >= 1.26.0, < 2.0.0dev", + "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", ), python_requires=">=3.6", scripts=[], diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py new file mode 100644 index 0000000000..e18390a76a --- /dev/null +++ b/tests/system/aiplatform/test_dataset.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import uuid +import pytest +import importlib + +from google import auth as google_auth +from google.protobuf import json_format +from google.api_core import exceptions +from google.api_core import client_options + +from google.cloud import storage +from google.cloud import aiplatform +from google.cloud.aiplatform import utils +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.services import dataset_service + +from test_utils.vpcsc_config import vpcsc_config + +# TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported +_, _TEST_PROJECT = google_auth.default() +TEST_BUCKET = os.environ.get( + "GCLOUD_TEST_SAMPLES_BUCKET", "cloud-samples-data-us-central1" +) + +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_API_ENDPOINT = f"{_TEST_LOCATION}-aiplatform.googleapis.com" +_TEST_IMAGE_DATASET_ID = "1084241610289446912" # permanent_50_flowers_dataset +_TEST_TEXT_DATASET_ID = ( + "6203215905493614592" # permanent_text_entity_extraction_dataset +) +_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset" +_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv" +_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE = f"gs://{TEST_BUCKET}/ai-platform-unified/sdk/datasets/text_entity_extraction_dataset.jsonl" +_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE = ( + "gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl" +) +_TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml" +_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml" + + +class TestDataset: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + @pytest.fixture() + def shared_state(self): + shared_state = {} + yield shared_state + + @pytest.fixture() + def create_staging_bucket(self, shared_state): + new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}" + + storage_client = storage.Client() + storage_client.create_bucket(new_staging_bucket) + shared_state["storage_client"] = storage_client + shared_state["staging_bucket"] = new_staging_bucket + yield + + @pytest.fixture() + def delete_staging_bucket(self, shared_state): + yield + storage_client = shared_state["storage_client"] + + # Delete temp staging bucket + bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"]) + bucket_to_delete.delete(force=True) + + # Close Storage Client + storage_client._http._auth_request.session.close() + storage_client._http.close() + + @pytest.fixture() + def dataset_gapic_client(self): + gapic_client = dataset_service.DatasetServiceClient( + client_options=client_options.ClientOptions(api_endpoint=_TEST_API_ENDPOINT) + ) + + yield gapic_client + + @pytest.fixture() + def create_text_dataset(self, dataset_gapic_client, shared_state): + + gapic_dataset = gca_dataset.Dataset( + display_name=f"temp_sdk_integration_test_create_text_dataset_{uuid.uuid4()}", + metadata_schema_uri=aiplatform.schema.dataset.metadata.text, + ) + + create_lro = dataset_gapic_client.create_dataset( + parent=_TEST_PARENT, dataset=gapic_dataset + ) + new_dataset = create_lro.result() + shared_state["dataset_name"] = new_dataset.name + yield + + @pytest.fixture() + def create_tabular_dataset(self, dataset_gapic_client, shared_state): + + gapic_dataset = gca_dataset.Dataset( + display_name=f"temp_sdk_integration_test_create_tabular_dataset_{uuid.uuid4()}", + metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular, + ) + + create_lro = dataset_gapic_client.create_dataset( + parent=_TEST_PARENT, dataset=gapic_dataset + ) + new_dataset = create_lro.result() + shared_state["dataset_name"] = new_dataset.name + yield + + @pytest.fixture() + def create_image_dataset(self, dataset_gapic_client, shared_state): + + gapic_dataset = gca_dataset.Dataset( + display_name=f"temp_sdk_integration_test_create_image_dataset_{uuid.uuid4()}", + metadata_schema_uri=aiplatform.schema.dataset.metadata.image, + ) + + create_lro = dataset_gapic_client.create_dataset( + parent=_TEST_PARENT, dataset=gapic_dataset + ) + new_dataset = create_lro.result() + shared_state["dataset_name"] = new_dataset.name + yield + + @pytest.fixture() + def delete_new_dataset(self, dataset_gapic_client, shared_state): + yield + assert shared_state["dataset_name"] + + deletion_lro = dataset_gapic_client.delete_dataset( + name=shared_state["dataset_name"] + ) + deletion_lro.result() + + shared_state["dataset_name"] = None + + # TODO(vinnys): Remove pytest skip once persistent resources are accessible + @pytest.mark.skip(reason="System tests cannot access persistent test resources") + def test_get_existing_dataset(self): + """Retrieve a known existing dataset, ensure SDK successfully gets the + dataset resource.""" + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + flowers_dataset = aiplatform.ImageDataset(dataset_name=_TEST_IMAGE_DATASET_ID) + assert flowers_dataset.name == _TEST_IMAGE_DATASET_ID + assert flowers_dataset.display_name == _TEST_DATASET_DISPLAY_NAME + + def test_get_nonexistent_dataset(self): + """Ensure attempting to retrieve a dataset that doesn't exist raises + a Google API core 404 exception.""" + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # AI Platform service returns 404 + with pytest.raises(exceptions.NotFound): + aiplatform.ImageDataset(dataset_name="0") + + @pytest.mark.usefixtures("create_text_dataset", "delete_new_dataset") + def test_get_new_dataset_and_import(self, dataset_gapic_client, shared_state): + """Retrieve new, empty dataset and import a text dataset using import(). + Then verify data items were successfully imported.""" + + assert shared_state["dataset_name"] + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_dataset = aiplatform.TextDataset(dataset_name=shared_state["dataset_name"]) + + data_items_pre_import = dataset_gapic_client.list_data_items( + parent=my_dataset.resource_name + ) + + assert len(list(data_items_pre_import)) == 0 + + # Blocking call to import + my_dataset.import_data( + gcs_source=_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE, + import_schema_uri=_TEST_TEXT_ENTITY_IMPORT_SCHEMA, + ) + + data_items_post_import = dataset_gapic_client.list_data_items( + parent=my_dataset.resource_name + ) + + assert len(list(data_items_post_import)) == 469 + + @vpcsc_config.skip_if_inside_vpcsc + @pytest.mark.usefixtures("delete_new_dataset") + def test_create_and_import_image_dataset(self, dataset_gapic_client, shared_state): + """Use the Dataset.create() method to create a new image obj detection + dataset and import images. Then confirm images were successfully imported.""" + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + img_dataset = aiplatform.ImageDataset.create( + display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}", + gcs_source=_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE, + import_schema_uri=_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA, + ) + + shared_state["dataset_name"] = img_dataset.resource_name + + data_items_iterator = dataset_gapic_client.list_data_items( + parent=img_dataset.resource_name + ) + + assert len(list(data_items_iterator)) == 14 + + @pytest.mark.usefixtures("delete_new_dataset") + def test_create_tabular_dataset(self, dataset_gapic_client, shared_state): + """Use the Dataset.create() method to create a new tabular dataset. + Then confirm the dataset was successfully created and references GCS source.""" + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + tabular_dataset = aiplatform.TabularDataset.create( + display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}", + gcs_source=[_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE], + ) + + gapic_dataset = tabular_dataset._gca_resource + shared_state["dataset_name"] = tabular_dataset.resource_name + + gapic_metadata = json_format.MessageToDict(gapic_dataset._pb.metadata) + gcs_source_uris = gapic_metadata["inputConfig"]["gcsSource"]["uri"] + + assert len(gcs_source_uris) == 1 + assert _TEST_TABULAR_CLASSIFICATION_GCS_SOURCE == gcs_source_uris[0] + assert ( + gapic_dataset.metadata_schema_uri + == aiplatform.schema.dataset.metadata.tabular + ) + + # TODO(vinnys): Remove pytest skip once persistent resources are accessible + @pytest.mark.skip(reason="System tests cannot access persistent test resources") + @pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket") + def test_export_data(self, shared_state): + """Get an existing dataset, export data to a newly created folder in + Google Cloud Storage, then verify data was successfully exported.""" + + assert shared_state["staging_bucket"] + assert shared_state["storage_client"] + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=f"gs://{shared_state['staging_bucket']}", + ) + + text_dataset = aiplatform.TextDataset(dataset_name=_TEST_TEXT_DATASET_ID) + + exported_files = text_dataset.export_data( + output_dir=f"gs://{shared_state['staging_bucket']}" + ) + + assert len(exported_files) # Ensure at least one GCS path was returned + + exported_file = exported_files[0] + bucket, prefix = utils.extract_bucket_and_prefix_from_gcs_path(exported_file) + + storage_client = shared_state["storage_client"] + + bucket = storage_client.get_bucket(bucket) + blob = bucket.get_blob(prefix) + + assert blob # Verify the returned GCS export path exists diff --git a/tests/unit/aiplatform/test_automl_image_training_jobs.py b/tests/unit/aiplatform/test_automl_image_training_jobs.py new file mode 100644 index 0000000000..ec0de7140b --- /dev/null +++ b/tests/unit/aiplatform/test_automl_image_training_jobs.py @@ -0,0 +1,434 @@ +import pytest +import importlib +from unittest import mock + +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_IMAGE = schema.dataset.metadata.image + +_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS = 1000 +_TEST_TRAINING_DISABLE_EARLY_STOPPING = True +_TEST_MODEL_TYPE_ICN = "CLOUD" # Image Classification default +_TEST_MODEL_TYPE_IOD = "CLOUD_HIGH_ACCURACY_1" # Image Object Detection default +_TEST_MODEL_TYPE_MOBILE = "MOBILE_TF_LOW_LATENCY_1" +_TEST_PREDICTION_TYPE_ICN = "classification" +_TEST_PREDICTION_TYPE_IOD = "object_detection" + +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_MODEL_ID = "98777645321" + +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + { + "modelType": "CLOUD", + "budgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + "multiLabel": False, + "disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING, + }, + struct_pb2.Value(), +) + +_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL = json_format.ParseDict( + { + "modelType": "CLOUD", + "budgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + "multiLabel": False, + "disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING, + "baseModelId": _TEST_MODEL_ID, + }, + struct_pb2.Value(), +) + +_TEST_FRACTION_SPLIT_TRAINING = 0.6 +_TEST_FRACTION_SPLIT_VALIDATION = 0.2 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" +) + +_TEST_PIPELINE_RESOURCE_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_image(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_model_image(): + model = mock.MagicMock(models.Model) + model.name = _TEST_MODEL_ID + model._latest_future = None + model._exception = None + model._gca_resource = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description="This is the mock Model's description", + name=_TEST_MODEL_NAME, + ) + yield model + + +class TestAutoMLImageTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_all_parameters(self, mock_model_image): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_ICN, + model_type=_TEST_MODEL_TYPE_MOBILE, + base_model=mock_model_image, + multi_label=True, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert job._model_type == _TEST_MODEL_TYPE_MOBILE + assert job._prediction_type == _TEST_PREDICTION_TYPE_ICN + assert job._multi_label is True + assert job._base_model == mock_model_image + + def test_init_wrong_parameters(self, mock_model_image): + """Ensure correct exceptions are raised when initializing with invalid args""" + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError, match=r"not a supported prediction type"): + training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, prediction_type="abcdefg", + ) + + with pytest.raises(ValueError, match=r"not a supported model_type for"): + training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + model_type=_TEST_MODEL_TYPE_IOD, + ) + + with pytest.raises(ValueError, match=r"`base_model` is only supported"): + training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_IOD, + base_model=mock_model_image, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_image, + mock_model_service_get, + mock_model_image, + sync, + ): + """Create and run an AutoML ICN training job, verify calls and return value""" + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, base_model=mock_model_image + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model_image._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_image, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises(self, mock_dataset_image, sync): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_image, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_image, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py new file mode 100644 index 0000000000..62cab4b3c3 --- /dev/null +++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py @@ -0,0 +1,441 @@ +import importlib +import pytest +from unittest import mock + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +_TEST_BUCKET_NAME = "test-bucket" +_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" +_TEST_GCS_PATH = f"{_TEST_BUCKET_NAME}/{_TEST_GCS_PATH_WITHOUT_BUCKET}" +_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" +_TEST_PROJECT = "test-project" + +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" +_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular +_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image + +_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [ + {"auto": {"column_name": "sepal_width"}}, + {"auto": {"column_name": "sepal_length"}}, + {"auto": {"column_name": "petal_length"}}, + {"auto": {"column_name": "petal_width"}}, +] +_TEST_TRAINING_TARGET_COLUMN = "target" +_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS = 1000 +_TEST_TRAINING_WEIGHT_COLUMN = "weight" +_TEST_TRAINING_DISABLE_EARLY_STOPPING = True +_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-log-loss" +_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE = "classification" +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + { + # required inputs + "targetColumn": _TEST_TRAINING_TARGET_COLUMN, + "transformations": _TEST_TRAINING_COLUMN_TRANSFORMATIONS, + "trainBudgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + # optional inputs + "weightColumnName": _TEST_TRAINING_WEIGHT_COLUMN, + "disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING, + "predictionType": _TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + "optimizationObjective": _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + "optimizationObjectiveRecallValue": None, + "optimizationObjectivePrecisionValue": None, + }, + struct_pb2.Value(), +) + +_TEST_DATASET_NAME = "test-dataset-name" + +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_TRAINING_FRACTION_SPLIT = 0.6 +_TEST_VALIDATION_FRACTION_SPLIT = 0.2 +_TEST_TEST_FRACTION_SPLIT = 0.2 +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" + +_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" + +_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" + +_TEST_PIPELINE_RESOURCE_NAME = ( + "projects/my-project/locations/us-central1/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_tabular(): + ds = mock.MagicMock(datasets.TabularDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_dataset_nontabular(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +class TestAutoMLTabularTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises(self, mock_dataset_tabular, sync): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_tabular, sync + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state diff --git a/tests/unit/aiplatform/test_automl_text_training_jobs.py b/tests/unit/aiplatform/test_automl_text_training_jobs.py new file mode 100644 index 0000000000..101ff79ef5 --- /dev/null +++ b/tests/unit/aiplatform/test_automl_text_training_jobs.py @@ -0,0 +1,618 @@ +import pytest +import importlib +from unittest import mock + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) +from google.cloud.aiplatform.v1.schema.trainingjob import ( + definition_v1 as training_job_inputs, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_TEXT = schema.dataset.metadata.text + +_TEST_PREDICTION_TYPE_CLASSIFICATION = "classification" +_TEST_CLASSIFICATION_MULTILABEL = True +_TEST_PREDICTION_TYPE_EXTRACTION = "extraction" +_TEST_PREDICTION_TYPE_SENTIMENT = "sentiment" +_TEST_SENTIMENT_MAX = 10 + +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_MODEL_ID = "98777645321" + +_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION = training_job_inputs.AutoMlTextClassificationInputs( + multi_label=_TEST_CLASSIFICATION_MULTILABEL +) +_TEST_TRAINING_TASK_INPUTS_EXTRACTION = training_job_inputs.AutoMlTextExtractionInputs() +_TEST_TRAINING_TASK_INPUTS_SENTIMENT = training_job_inputs.AutoMlTextSentimentInputs( + sentiment_max=_TEST_SENTIMENT_MAX +) + +_TEST_FRACTION_SPLIT_TRAINING = 0.6 +_TEST_FRACTION_SPLIT_VALIDATION = 0.2 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" +) + +_TEST_PIPELINE_RESOURCE_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_text(): + ds = mock.MagicMock(datasets.TextDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_model(): + model = mock.MagicMock(models.Model) + model.name = _TEST_MODEL_ID + model._latest_future = None + model._gca_resource = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, name=_TEST_MODEL_NAME, + ) + yield model + + +class TestAutoMLTextTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_all_parameters_classification(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert ( + job._training_task_definition + == schema.training_job.definition.automl_text_classification + ) + assert ( + job._training_task_inputs_dict + == training_job_inputs.AutoMlTextClassificationInputs( + multi_label=_TEST_CLASSIFICATION_MULTILABEL + ) + ) + + def test_init_all_parameters_extraction(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_EXTRACTION, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert ( + job._training_task_definition + == schema.training_job.definition.automl_text_extraction + ) + assert ( + job._training_task_inputs_dict + == training_job_inputs.AutoMlTextExtractionInputs() + ) + + def test_init_all_parameters_sentiment(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_SENTIMENT, + sentiment_max=_TEST_SENTIMENT_MAX, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert ( + job._training_task_definition + == schema.training_job.definition.automl_text_sentiment + ) + assert ( + job._training_task_inputs_dict + == training_job_inputs.AutoMlTextSentimentInputs( + sentiment_max=_TEST_SENTIMENT_MAX + ) + ) + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_training_job( + self, + mock_pipeline_service_create, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Text Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_classification( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """Create and run an AutoML Text Classification training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_extraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """Create and run an AutoML Text Extraction training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_EXTRACTION, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_extraction, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_EXTRACTION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_sentiment( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """Create and run an AutoML Text Sentiment training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_SENTIMENT, + sentiment_max=10, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_sentiment, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_SENTIMENT, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_text, + mock_model_service_get, + mock_model, + sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + multi_label=True, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + model_display_name=None, # Omit model_display_name + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises(self, mock_dataset_text, sync): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + multi_label=True, + ) + + job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_text, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_text, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() diff --git a/tests/unit/aiplatform/test_automl_video_training_jobs.py b/tests/unit/aiplatform/test_automl_video_training_jobs.py new file mode 100644 index 0000000000..66f1692fcf --- /dev/null +++ b/tests/unit/aiplatform/test_automl_video_training_jobs.py @@ -0,0 +1,463 @@ +import pytest +import importlib +from unittest import mock + +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_VIDEO = schema.dataset.metadata.video + +_TEST_MODEL_TYPE_CLOUD = "CLOUD" +_TEST_MODEL_TYPE_MOBILE = "MOBILE_VERSATILE_1" + +_TEST_PREDICTION_TYPE_VAR = "action_recognition" +_TEST_PREDICTION_TYPE_VCN = "classification" +_TEST_PREDICTION_TYPE_VOR = "object_tracking" + +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_MODEL_ID = "98777645321" # TODO + +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + {"modelType": "CLOUD"}, struct_pb2.Value(), +) + +_TEST_FRACTION_SPLIT_TRAINING = 0.8 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" +) + +_TEST_PIPELINE_RESOURCE_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_video(): + ds = mock.MagicMock(datasets.VideoDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_model(): + model = mock.MagicMock(models.Model) + model.name = _TEST_MODEL_ID + model._latest_future = None + model._exception = None + model._gca_resource = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, name=_TEST_MODEL_NAME, + ) + yield model + + +class TestAutoMLVideoTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_all_parameters(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert job._model_type == _TEST_MODEL_TYPE_CLOUD + assert job._prediction_type == _TEST_PREDICTION_TYPE_VCN + + def test_init_wrong_parameters(self): + """Ensure correct exceptions are raised when initializing with invalid args""" + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError, match=r"not a supported prediction type"): + training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, prediction_type="abcdefg", + ) + + with pytest.raises(ValueError, match=r"not a supported model_type for"): + training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type="abcdefg", + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_training_job( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """Create and run an AutoML ICN training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_video, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_dataset_video, sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_video, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_video, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state diff --git a/tests/unit/aiplatform/test_base.py b/tests/unit/aiplatform/test_base.py new file mode 100644 index 0000000000..97f35b9476 --- /dev/null +++ b/tests/unit/aiplatform/test_base.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +import pytest +import time +from typing import Optional + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer + + +class _TestClass(base.FutureManager): + def __init__(self, x): + self.x = x + super().__init__() + + @classmethod + def _empty_constructor(cls): + self = cls.__new__(cls) + base.FutureManager.__init__(self) + self.x = None + return self + + def _sync_object_with_future_result(self, result): + self.x = result.x + + @classmethod + @base.optional_sync() + def create(cls, x: int, sync=True) -> "_TestClass": + time.sleep(1) + return cls(x) + + @base.optional_sync() + def add(self, a: "_TestClass", sync=True) -> None: + time.sleep(1) + return self._add(a=a, sync=sync) + + def _add(self, a: "_TestClass", sync=True) -> None: + self.x = self.x + a.x + + +class _TestClassDownStream(_TestClass): + @base.optional_sync(construct_object_on_arg="a") + def add_and_create_new( + self, a: Optional["_TestClass"] = None, sync=True + ) -> _TestClass: + time.sleep(1) + if a: + return _TestClass(self.x + a.x) + return None + + @base.optional_sync(return_input_arg="a", bind_future_to_self=False) + def add_to_input_arg(self, a: "_TestClass", sync=True) -> _TestClass: + time.sleep(1) + a._add(self) + return a + + +class TestFutureManager: + def setup_method(self): + reload(initializer) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_task(self, sync): + a = _TestClass.create(10, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + a.wait() + assert a._latest_future is None + assert a.x == 10 + assert isinstance(a, _TestClass) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_task(self, sync): + _latest_future = None + + a = _TestClass.create(10, sync=sync) + b = _TestClass.create(7, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + b.add(a, sync=sync) + + if not sync: + assert b._latest_future is not _latest_future + b.wait() + + assert a._latest_future is None + assert a.x == 10 + assert b._latest_future is None + assert b.x == 17 + assert isinstance(a, _TestClass) + assert isinstance(b, _TestClass) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_and_create_new_task(self, sync): + _latest_future = None + + a = _TestClass.create(10, sync=sync) + b = _TestClassDownStream.create(7, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + c = b.add_and_create_new(a, sync=sync) + + if not sync: + assert b._latest_future is not _latest_future + assert c.x is None + assert c._latest_future is not None + c.wait() + + assert a._latest_future is None + assert a.x == 10 + assert b._latest_future is None + assert b.x == 7 + assert c._latest_future is None + assert c.x == 17 + assert isinstance(a, _TestClass) + assert isinstance(b, _TestClassDownStream) + assert isinstance(c, _TestClass) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_and_not_create_new_task(self, sync): + _latest_future = None + + b = _TestClassDownStream.create(7, sync=sync) + if not sync: + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + c = b.add_and_create_new(None, sync=sync) + + if not sync: + assert b._latest_future is not _latest_future + b.wait() + + assert c is None + + assert b._latest_future is None + assert b.x == 7 + assert isinstance(b, _TestClassDownStream) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_return_arg(self, sync): + _latest_future = None + + a = _TestClass.create(10, sync=sync) + b = _TestClassDownStream.create(7, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + c = b.add_to_input_arg(a, sync=sync) + + if not sync: + assert b._latest_future is _latest_future + assert c.x is None + assert c._latest_future is not None + assert c is a + c.wait() + + assert a._latest_future is None + assert a.x == 17 + assert b._latest_future is None + assert b.x == 7 + assert c._latest_future is None + assert c.x == 17 + assert isinstance(a, _TestClass) + assert isinstance(b, _TestClassDownStream) + assert isinstance(c, _TestClass) diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py new file mode 100644 index 0000000000..918f753dbf --- /dev/null +++ b/tests/unit/aiplatform/test_datasets.py @@ -0,0 +1,1190 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import pytest + +from unittest import mock +from importlib import reload +from unittest.mock import patch + +from google.api_core import operation +from google.auth.exceptions import GoogleAuthError +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema + +from google.cloud.aiplatform_v1.services.dataset_service import ( + client as dataset_service_client, +) + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + dataset_service as gca_dataset_service, + encryption_spec as gca_encryption_spec, + io as gca_io, +) + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_ALT_PROJECT = "test-project_alt" + +_TEST_ALT_LOCATION = "europe-west4" +_TEST_INVALID_LOCATION = "us-central2" + +# dataset +_TEST_ID = "1028944691210842416" +_TEST_DISPLAY_NAME = "my_dataset_1234" +_TEST_DATA_LABEL_ITEMS = None + +_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}" +_TEST_ALT_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_ALT_LOCATION}/datasets/{_TEST_ID}" +) +_TEST_INVALID_NAME = f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/{_TEST_ID}" + +# metadata_schema_uri +_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular +_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image +_TEST_METADATA_SCHEMA_URI_IMAGE = schema.dataset.metadata.image +_TEST_METADATA_SCHEMA_URI_TEXT = schema.dataset.metadata.text +_TEST_METADATA_SCHEMA_URI_VIDEO = schema.dataset.metadata.video + +# import_schema_uri +_TEST_IMPORT_SCHEMA_URI_IMAGE = ( + schema.dataset.ioformat.image.single_label_classification +) +_TEST_IMPORT_SCHEMA_URI_TEXT = schema.dataset.ioformat.text.single_label_classification +_TEST_IMPORT_SCHEMA_URI = schema.dataset.ioformat.image.single_label_classification +_TEST_IMPORT_SCHEMA_URI_VIDEO = schema.dataset.ioformat.video.classification + +# datasources +_TEST_SOURCE_URI_GCS = "gs://my-bucket/my_index_file.jsonl" +_TEST_SOURCE_URIS_GCS = [ + "gs://my-bucket/index_file_1.jsonl", + "gs://my-bucket/index_file_2.jsonl", + "gs://my-bucket/index_file_3.jsonl", +] +_TEST_SOURCE_URI_BQ = "bigquery://my-project/my-dataset" +_TEST_INVALID_SOURCE_URIS = ["gs://my-bucket/index_file_1.jsonl", 123] + +# request_metadata +_TEST_REQUEST_METADATA = () + +# dataset_metadata +_TEST_NONTABULAR_DATASET_METADATA = None +_TEST_METADATA_TABULAR_GCS = { + "input_config": {"gcs_source": {"uri": [_TEST_SOURCE_URI_GCS]}} +} +_TEST_METADATA_TABULAR_BQ = { + "input_config": {"bigquery_source": {"uri": _TEST_SOURCE_URI_BQ}} +} + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + +# misc +_TEST_OUTPUT_DIR = "gs://my-output-bucket" + +_TEST_DATASET_LIST = [ + gca_dataset.Dataset( + display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), + gca_dataset.Dataset( + display_name="d", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR + ), + gca_dataset.Dataset( + display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), + gca_dataset.Dataset( + display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT + ), + gca_dataset.Dataset( + display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), +] + +_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY = "create_time desc" + + +@pytest.fixture +def get_dataset_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + name=_TEST_NAME, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_without_name_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_image_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_tabular_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_text_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_video_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def create_dataset_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "create_dataset" + ) as create_dataset_mock: + create_dataset_lro_mock = mock.Mock(operation.Operation) + create_dataset_lro_mock.result.return_value = gca_dataset.Dataset( + name=_TEST_NAME, + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.return_value = create_dataset_lro_mock + yield create_dataset_mock + + +@pytest.fixture +def delete_dataset_mock(): + with mock.patch.object( + dataset_service_client.DatasetServiceClient, "delete_dataset" + ) as delete_dataset_mock: + delete_dataset_lro_mock = mock.Mock(operation.Operation) + delete_dataset_lro_mock.result.return_value = ( + gca_dataset_service.DeleteDatasetRequest() + ) + delete_dataset_mock.return_value = delete_dataset_lro_mock + yield delete_dataset_mock + + +@pytest.fixture +def import_data_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "import_data" + ) as import_data_mock: + import_data_mock.return_value = mock.Mock(operation.Operation) + yield import_data_mock + + +@pytest.fixture +def export_data_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "export_data" + ) as export_data_mock: + export_data_mock.return_value = mock.Mock(operation.Operation) + yield export_data_mock + + +@pytest.fixture +def list_datasets_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "list_datasets" + ) as list_datasets_mock: + list_datasets_mock.return_value = _TEST_DATASET_LIST + yield list_datasets_mock + + +# TODO(b/171333554): Move reusable test fixtures to conftest.py file +class TestDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset(dataset_name=_TEST_NAME) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_id_only_with_project_and_location( + self, get_dataset_mock + ): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset( + dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_project_and_location(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset( + dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset( + dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_project_and_alt_location(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(RuntimeError): + datasets._Dataset( + dataset_name=_TEST_NAME, + project=_TEST_PROJECT, + location=_TEST_ALT_LOCATION, + ) + + def test_init_dataset_with_id_only(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + datasets._Dataset(dataset_name=_TEST_ID) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_without_name_mock") + @patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "", "GOOGLE_APPLICATION_CREDENTIALS": ""} + ) + def test_init_dataset_with_id_only_without_project_or_location(self): + with pytest.raises(GoogleAuthError): + datasets._Dataset( + dataset_name=_TEST_ID, + credentials=auth_credentials.AnonymousCredentials(), + ) + + def test_init_dataset_with_location_override(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + datasets._Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION) + get_dataset_mock.assert_called_once_with(name=_TEST_ALT_NAME) + + @pytest.mark.usefixtures("get_dataset_mock") + def test_init_dataset_with_invalid_name(self): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + datasets._Dataset(dataset_name=_TEST_INVALID_NAME) + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_dataset( + self, create_dataset_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset_nontabular(self, create_dataset_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_mock") + def test_create_dataset_tabular(self, create_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + + datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + bq_source=_TEST_SOURCE_URI_BQ, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=_TEST_SOURCE_URI_GCS, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=_TEST_SOURCE_URI_GCS, + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.usefixtures("get_dataset_mock") + def test_export_data(self, export_data_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) + + my_dataset.export_data(output_dir=_TEST_OUTPUT_DIR) + + expected_export_config = gca_dataset.ExportDataConfig( + gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_DIR) + ) + + export_data_mock.assert_called_once_with( + name=_TEST_NAME, export_config=expected_export_config + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=_TEST_SOURCE_URI_GCS, + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_dataset(self, delete_dataset_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + my_dataset.delete(sync=sync) + + if not sync: + my_dataset.wait() + + delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name) + + +class TestImageDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_image(self, get_dataset_image_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.ImageDataset(dataset_name=_TEST_NAME) + get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + def test_init_dataset_non_image(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): + datasets.ImageDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_image_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets.ImageDataset.create( + display_name=_TEST_DISPLAY_NAME, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_image_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.ImageDataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + ) + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_image_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.ImageDataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_image_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.ImageDataset.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + +class TestTabularDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_tabular(self, get_dataset_tabular_mock): + + datasets.TabularDataset(dataset_name=_TEST_NAME) + get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_image_mock") + def test_init_dataset_non_tabular(self): + + with pytest.raises(ValueError): + datasets.TabularDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset_with_default_encryption_key( + self, create_dataset_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets.TabularDataset.create( + display_name=_TEST_DISPLAY_NAME, bq_source=_TEST_SOURCE_URI_BQ, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + + my_dataset = datasets.TabularDataset.create( + display_name=_TEST_DISPLAY_NAME, + bq_source=_TEST_SOURCE_URI_BQ, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + def test_no_import_data_method(self): + + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + with pytest.raises(NotImplementedError): + my_dataset.import_data() + + def test_list_dataset(self, list_datasets_mock): + + ds_list = aiplatform.TabularDataset.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY + ) + + list_datasets_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + # Ensure returned list is smaller since it filtered out non-tabular datasets + assert len(ds_list) < len(_TEST_DATASET_LIST) + + for ds in ds_list: + assert type(ds) == aiplatform.TabularDataset + + def test_list_dataset_no_order_or_filter(self, list_datasets_mock): + + ds_list = aiplatform.TabularDataset.list() + + list_datasets_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": None} + ) + + # Ensure returned list is smaller since it filtered out non-tabular datasets + assert len(ds_list) < len(_TEST_DATASET_LIST) + + for ds in ds_list: + assert type(ds) == aiplatform.TabularDataset + + +class TestTextDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_text(self, get_dataset_text_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.TextDataset(dataset_name=_TEST_NAME) + get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_image_mock") + def test_init_dataset_non_text(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): + datasets.TextDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_text_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets.TextDataset.create( + display_name=_TEST_DISPLAY_NAME, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_text_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.TextDataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + ) + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_text_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME + ) + + my_dataset = datasets.TextDataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_text_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.TextDataset.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + +class TestVideoDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_video(self, get_dataset_video_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.VideoDataset(dataset_name=_TEST_NAME) + get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + def test_init_dataset_non_video(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): + datasets.VideoDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_video_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME + ) + + my_dataset = datasets.VideoDataset.create( + display_name=_TEST_DISPLAY_NAME, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_video_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.VideoDataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + ) + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_video_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.VideoDataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_video_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.VideoDataset.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset diff --git a/tests/unit/aiplatform/test_end_to_end.py b/tests/unit/aiplatform/test_end_to_end.py new file mode 100644 index 0000000000..69c5517a69 --- /dev/null +++ b/tests/unit/aiplatform/test_end_to_end.py @@ -0,0 +1,462 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from importlib import reload + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +import test_datasets +from test_datasets import create_dataset_mock # noqa: F401 +from test_datasets import get_dataset_mock # noqa: F401 +from test_datasets import import_data_mock # noqa: F401 + +import test_endpoints +from test_endpoints import create_endpoint_mock # noqa: F401 +from test_endpoints import get_endpoint_mock # noqa: F401 +from test_endpoints import predict_client_predict_mock # noqa: F401 + +from test_models import deploy_model_mock # noqa: F401 + +import test_training_jobs +from test_training_jobs import mock_model_service_get # noqa: F401 +from test_training_jobs import mock_pipeline_service_create # noqa: F401 +from test_training_jobs import mock_pipeline_service_get # noqa: F401 +from test_training_jobs import ( # noqa: F401 + mock_pipeline_service_create_and_get_with_fail, +) +from test_training_jobs import mock_python_package_to_gcs # noqa: F401 + +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +# dataset_encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + + +class TestEndToEnd: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures( + "get_dataset_mock", + "create_endpoint_mock", + "get_endpoint_mock", + "deploy_model_mock", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_dataset_create_to_model_predict( + self, + create_dataset_mock, # noqa: F811 + import_data_mock, # noqa: F811 + predict_client_predict_mock, # noqa: F811 + mock_python_package_to_gcs, # noqa: F811 + mock_pipeline_service_create, # noqa: F811 + mock_model_service_get, # noqa: F811 + mock_pipeline_service_get, # noqa: F811 + sync, + ): + + aiplatform.init( + project=test_datasets._TEST_PROJECT, + staging_bucket=test_training_jobs._TEST_BUCKET_NAME, + credentials=test_training_jobs._TEST_CREDENTIALS, + ) + + my_dataset = aiplatform.ImageDataset.create( + display_name=test_datasets._TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=test_datasets._TEST_SOURCE_URI_GCS, + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + job = aiplatform.CustomTrainingJob( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + dataset=my_dataset, + base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR, + args=test_training_jobs._TEST_RUN_ARGS, + replica_count=1, + machine_type=test_training_jobs._TEST_MACHINE_TYPE, + accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE, + accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT, + model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + training_fraction_split=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + created_endpoint = models.Endpoint.create( + display_name=test_endpoints._TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_endpoint = model_from_job.deploy( + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync + ) + + endpoint_deploy_return = created_endpoint.deploy(model_from_job, sync=sync) + + assert endpoint_deploy_return is None + + if not sync: + my_endpoint.wait() + created_endpoint.wait() + + test_prediction = created_endpoint.predict( + instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], parameters={"param": 3.0} + ) + + true_prediction = models.Prediction( + predictions=test_endpoints._TEST_PREDICTION, + deployed_model_id=test_endpoints._TEST_ID, + ) + + assert true_prediction == test_prediction + predict_client_predict_mock.assert_called_once_with( + endpoint=test_endpoints._TEST_ENDPOINT_NAME, + instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], + parameters={"param": 3.0}, + ) + + expected_dataset = gca_dataset.Dataset( + display_name=test_datasets._TEST_DISPLAY_NAME, + metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=test_datasets._TEST_PARENT, + dataset=expected_dataset, + metadata=test_datasets._TEST_REQUEST_METADATA, + ) + + import_data_mock.assert_called_once_with( + name=test_datasets._TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = test_datasets._TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME, + project=test_training_jobs._TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = test_training_jobs._TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": test_training_jobs._TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": test_training_jobs._TEST_MACHINE_TYPE, + "acceleratorType": test_training_jobs._TEST_ACCELERATOR_TYPE, + "acceleratorCount": test_training_jobs._TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=my_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": { + "output_uri_prefix": test_training_jobs._TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with( + name=test_training_jobs._TEST_MODEL_NAME + ) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "get_dataset_mock", + "create_endpoint_mock", + "get_endpoint_mock", + "deploy_model_mock", + ) + def test_dataset_create_to_model_predict_with_pipeline_fail( + self, + create_dataset_mock, # noqa: F811 + import_data_mock, # noqa: F811 + mock_python_package_to_gcs, # noqa: F811 + mock_pipeline_service_create_and_get_with_fail, # noqa: F811 + mock_model_service_get, # noqa: F811 + ): + + sync = False + + aiplatform.init( + project=test_datasets._TEST_PROJECT, + staging_bucket=test_training_jobs._TEST_BUCKET_NAME, + credentials=test_training_jobs._TEST_CREDENTIALS, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = aiplatform.ImageDataset.create( + display_name=test_datasets._TEST_DISPLAY_NAME, sync=sync, + ) + + my_dataset.import_data( + gcs_source=test_datasets._TEST_SOURCE_URI_GCS, + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + job = aiplatform.CustomTrainingJob( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + created_endpoint = models.Endpoint.create( + display_name=test_endpoints._TEST_DISPLAY_NAME, sync=sync, + ) + + model_from_job = job.run( + dataset=my_dataset, + base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR, + args=test_training_jobs._TEST_RUN_ARGS, + replica_count=1, + machine_type=test_training_jobs._TEST_MACHINE_TYPE, + accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE, + accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT, + model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + training_fraction_split=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + my_endpoint = model_from_job.deploy(sync=sync) + my_endpoint.wait() + + with pytest.raises(RuntimeError): + endpoint_deploy_return = created_endpoint.deploy(model_from_job, sync=sync) + assert endpoint_deploy_return is None + created_endpoint.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=test_datasets._TEST_DISPLAY_NAME, + metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=test_datasets._TEST_PARENT, + dataset=expected_dataset, + metadata=test_datasets._TEST_REQUEST_METADATA, + ) + + import_data_mock.assert_called_once_with( + name=test_datasets._TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = test_datasets._TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME, + project=test_training_jobs._TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = test_training_jobs._TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": test_training_jobs._TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": test_training_jobs._TEST_MACHINE_TYPE, + "acceleratorType": test_training_jobs._TEST_ACCELERATOR_TYPE, + "acceleratorCount": test_training_jobs._TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=my_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": { + "output_uri_prefix": test_training_jobs._TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create_and_get_with_fail[0].assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert ( + job._gca_resource + is mock_pipeline_service_create_and_get_with_fail[1].return_value + ) + + mock_model_service_get.assert_not_called() + + with pytest.raises(RuntimeError): + job.get_model() + + assert job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py new file mode 100644 index 0000000000..ea74c89e5e --- /dev/null +++ b/tests/unit/aiplatform/test_endpoints.py @@ -0,0 +1,1079 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from unittest import mock +from importlib import reload +from datetime import datetime, timedelta + +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + client as prediction_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.types import ( + endpoint as gca_endpoint_v1beta1, + machine_resources as gca_machine_resources_v1beta1, + prediction_service as gca_prediction_service_v1beta1, + endpoint_service as gca_endpoint_service_v1beta1, +) + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client, +) +from google.cloud.aiplatform_v1.services.prediction_service import ( + client as prediction_service_client, +) +from google.cloud.aiplatform_v1.types import ( + endpoint as gca_endpoint, + model as gca_model, + machine_resources as gca_machine_resources, + prediction_service as gca_prediction_service, + endpoint_service as gca_endpoint_service, + encryption_spec as gca_encryption_spec, +) + +_TEST_PROJECT = "test-project" +_TEST_PROJECT_2 = "test-project-2" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" + +_TEST_ENDPOINT_NAME = "test-endpoint" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_DISPLAY_NAME_2 = "test-display-name-2" +_TEST_ID = "1028944691210842416" +_TEST_ID_2 = "4366591682456584192" +_TEST_DESCRIPTION = "test-description" + +_TEST_ENDPOINT_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}" +) +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ID}" +) + +_TEST_MODEL_ID = "1028944691210842416" +_TEST_PREDICTION = [[1.0, 2.0, 3.0], [3.0, 3.0, 1.0]] +_TEST_INSTANCES = [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]] +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) + +_TEST_DEPLOYED_MODELS = [ + gca_endpoint.DeployedModel(id=_TEST_ID, display_name=_TEST_DISPLAY_NAME), + gca_endpoint.DeployedModel(id=_TEST_ID_2, display_name=_TEST_DISPLAY_NAME_2), +] + +_TEST_MACHINE_TYPE = "n1-standard-32" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +_TEST_ACCELERATOR_COUNT = 2 + +_TEST_EXPLANATIONS = [ + gca_prediction_service_v1beta1.explanation.Explanation(attributions=[]) +] + +_TEST_ATTRIBUTIONS = [ + gca_prediction_service_v1beta1.explanation.Attribution( + baseline_output_value=1.0, + instance_output_value=2.0, + feature_attributions=3.0, + output_index=[1, 2, 3], + output_display_name="abc", + approximation_error=6.0, + output_name="xyz", + ) +] + +_TEST_EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( + inputs={ + "features": aiplatform.explain.ExplanationMetadata.InputMetadata( + { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + ) + }, + outputs={ + "medv": aiplatform.explain.ExplanationMetadata.OutputMetadata( + {"output_tensor_name": "dense_2"} + ) + }, +) +_TEST_EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + + +_TEST_ENDPOINT_LIST = [ + gca_endpoint.Endpoint( + display_name="aac", create_time=datetime.now() - timedelta(minutes=15) + ), + gca_endpoint.Endpoint( + display_name="aab", create_time=datetime.now() - timedelta(minutes=5) + ), + gca_endpoint.Endpoint( + display_name="aaa", create_time=datetime.now() - timedelta(minutes=10) + ), +] + +_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY_CREATE_TIME = "create_time desc" +_TEST_LIST_ORDER_BY_DISPLAY_NAME = "display_name" + + +@pytest.fixture +def get_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_endpoint_mock + + +@pytest.fixture +def get_endpoint_with_models_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + deployed_models=_TEST_DEPLOYED_MODELS, + ) + yield get_endpoint_mock + + +@pytest.fixture +def get_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, name=_TEST_MODEL_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def create_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "create_endpoint" + ) as create_endpoint_mock: + create_endpoint_lro_mock = mock.Mock(ga_operation.Operation) + create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint( + name=_TEST_ENDPOINT_NAME, display_name=_TEST_DISPLAY_NAME + ) + create_endpoint_mock.return_value = create_endpoint_lro_mock + yield create_endpoint_mock + + +@pytest.fixture +def deploy_model_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint.DeployedModel( + model=_TEST_MODEL_NAME, display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def deploy_model_with_explanations_mock(): + with mock.patch.object( + endpoint_service_client_v1beta1.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint_v1beta1.DeployedModel( + model=_TEST_MODEL_NAME, display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service_v1beta1.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def undeploy_model_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "undeploy_model" + ) as undeploy_model_mock: + undeploy_model_lro_mock = mock.Mock(ga_operation.Operation) + undeploy_model_lro_mock.result.return_value = ( + gca_endpoint_service.UndeployModelResponse() + ) + undeploy_model_mock.return_value = undeploy_model_lro_mock + yield undeploy_model_mock + + +@pytest.fixture +def delete_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "delete_endpoint" + ) as delete_endpoint_mock: + delete_endpoint_lro_mock = mock.Mock(ga_operation.Operation) + delete_endpoint_lro_mock.result.return_value = ( + gca_endpoint_service.DeleteEndpointRequest() + ) + delete_endpoint_mock.return_value = delete_endpoint_lro_mock + yield delete_endpoint_mock + + +@pytest.fixture +def sdk_private_undeploy_mock(): + """Mocks the high-level Endpoint._undeploy() SDK private method""" + with mock.patch.object(aiplatform.Endpoint, "_undeploy") as sdk_undeploy_mock: + sdk_undeploy_mock.return_value = None + yield sdk_undeploy_mock + + +@pytest.fixture +def sdk_undeploy_all_mock(): + """Mocks the high-level Endpoint.undeploy_all() SDK method""" + with mock.patch.object( + aiplatform.Endpoint, "undeploy_all" + ) as sdk_undeploy_all_mock: + sdk_undeploy_all_mock.return_value = None + yield sdk_undeploy_all_mock + + +@pytest.fixture +def list_endpoints_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "list_endpoints" + ) as list_endpoints_mock: + list_endpoints_mock.return_value = _TEST_ENDPOINT_LIST + yield list_endpoints_mock + + +@pytest.fixture +def create_client_mock(): + with mock.patch.object( + initializer.global_config, "create_client", autospec=True, + ) as create_client_mock: + create_client_mock.return_value = mock.Mock( + spec=endpoint_service_client.EndpointServiceClient + ) + yield create_client_mock + + +@pytest.fixture +def predict_client_predict_mock(): + with mock.patch.object( + prediction_service_client.PredictionServiceClient, "predict" + ) as predict_mock: + predict_mock.return_value = gca_prediction_service.PredictResponse( + deployed_model_id=_TEST_MODEL_ID + ) + predict_mock.return_value.predictions.extend(_TEST_PREDICTION) + yield predict_mock + + +@pytest.fixture +def predict_client_explain_mock(): + with mock.patch.object( + prediction_service_client_v1beta1.PredictionServiceClient, "explain" + ) as predict_mock: + predict_mock.return_value = gca_prediction_service_v1beta1.ExplainResponse( + deployed_model_id=_TEST_MODEL_ID, + ) + predict_mock.return_value.predictions.extend(_TEST_PREDICTION) + predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS) + predict_mock.return_value.explanations[0].attributions.extend( + _TEST_ATTRIBUTIONS + ) + yield predict_mock + + +class TestEndpoint: + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_constructor(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + models.Endpoint(_TEST_ENDPOINT_NAME) + create_client_mock.assert_has_calls( + [ + mock.call( + client_class=utils.EndpointClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION, + prediction_client=False, + ), + mock.call( + client_class=utils.PredictionClientWithOverride, + credentials=None, + location_override=_TEST_LOCATION, + prediction_client=True, + ), + ] + ) + + def test_constructor_with_endpoint_id(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(_TEST_ID) + get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) + + def test_constructor_with_endpoint_name(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(_TEST_ENDPOINT_NAME) + get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) + + def test_constructor_with_custom_project(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2) + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) + + def test_constructor_with_custom_location(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) + + def test_constructor_with_custom_credentials(self, create_client_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = auth_credentials.AnonymousCredentials() + + models.Endpoint(_TEST_ENDPOINT_NAME, credentials=creds) + create_client_mock.assert_has_calls( + [ + mock.call( + client_class=utils.EndpointClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=False, + ), + mock.call( + client_class=utils.PredictionClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=True, + ), + ] + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_endpoint( + self, create_endpoint_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + my_endpoint = models.Endpoint.create(display_name=_TEST_DISPLAY_NAME, sync=sync) + + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, endpoint=expected_endpoint, metadata=(), + ) + + expected_endpoint.name = _TEST_ENDPOINT_NAME + assert my_endpoint._gca_resource == expected_endpoint + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create(self, create_endpoint_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_endpoint = models.Endpoint.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, endpoint=expected_endpoint, metadata=(), + ) + + expected_endpoint.name = _TEST_ENDPOINT_NAME + assert my_endpoint._gca_resource == expected_endpoint + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_with_description(self, create_endpoint_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_endpoint = models.Endpoint.create( + display_name=_TEST_DISPLAY_NAME, description=_TEST_DESCRIPTION, sync=sync + ) + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, description=_TEST_DESCRIPTION, + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, endpoint=expected_endpoint, metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(test_model, sync=sync) + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_display_name(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, deployed_model_display_name=_TEST_DISPLAY_NAME, sync=sync + ) + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_80(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=80, sync=sync) + + if not sync: + test_endpoint.wait() + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_120(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=120, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_negative(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=-18, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_min_replica(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, min_replica_count=-1, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_max_replica(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_split(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_split={"a": 99}, sync=sync) + + @pytest.mark.usefixtures("get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 100}, + ) + + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=70, sync=sync) + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"model1": 30, "0": 70}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_traffic_split(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 100}, + ) + + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, traffic_split={"model1": 30, "0": 70}, sync=sync + ) + + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"model1": 30, "0": 70}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources( + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + ) + expected_deployed_model = gca_endpoint_v1beta1.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) + deploy_model_with_explanations_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_min_replica_count(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, min_replica_count=2, sync=sync) + + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=2, max_replica_count=2, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_max_replica_count(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync) + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=2, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.parametrize( + "model1, model2, model3, percent", + [ + (100, None, None, 70), + (50, 50, None, 70), + (40, 60, None, 75), + (40, 60, None, 88), + (88, 12, None, 36), + (11, 89, None, 18), + (1, 99, None, 80), + (1, 2, 97, 68), + (99, 1, 0, 22), + (0, 0, 100, 18), + (7, 87, 6, 46), + ], + ) + def test_allocate_traffic(self, model1, model2, model3, percent): + old_split = {} + if model1 is not None: + old_split["model1"] = model1 + if model2 is not None: + old_split["model2"] = model2 + if model3 is not None: + old_split["model3"] = model3 + + new_split = models.Endpoint._allocate_traffic(old_split, percent) + new_split_sum = 0 + for model in new_split: + new_split_sum += new_split[model] + + assert new_split_sum == 100 + assert new_split["0"] == percent + + @pytest.mark.parametrize( + "model1, model2, model3, deployed_model", + [ + (100, None, None, "model1"), + (50, 50, None, "model1"), + (40, 60, None, "model2"), + (40, 60, None, "model1"), + (88, 12, None, "model1"), + (11, 89, None, "model1"), + (1, 99, None, "model2"), + (1, 2, 97, "model1"), + (99, 1, 0, "model2"), + (0, 0, 100, "model3"), + (7, 87, 6, "model2"), + ], + ) + def test_unallocate_traffic(self, model1, model2, model3, deployed_model): + old_split = {} + if model1 is not None: + old_split["model1"] = model1 + if model2 is not None: + old_split["model2"] = model2 + if model3 is not None: + old_split["model3"] = model3 + + new_split = models.Endpoint._unallocate_traffic(old_split, deployed_model) + new_split_sum = 0 + for model in new_split: + new_split_sum += new_split[model] + + assert new_split_sum == 100 or new_split_sum == 0 + assert new_split[deployed_model] == 0 + + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy(self, undeploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 100}, + ) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + assert dict(test_endpoint._gca_resource.traffic_split) == {"model1": 100} + test_endpoint.undeploy("model1", sync=sync) + if not sync: + test_endpoint.wait() + undeploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model_id="model1", + traffic_split={}, + # traffic_split={"model1": 0}, + metadata=(), + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 40, "model2": 60}, + ) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_endpoint.undeploy( + deployed_model_id="model1", + traffic_split={"model1": 0, "model2": 100}, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + undeploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model_id="model1", + traffic_split={"model2": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_raise_error_traffic_split_total(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_endpoint.undeploy( + deployed_model_id="model1", traffic_split={"model2": 99}, sync=sync + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_raise_error_undeployed_model_traffic(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_endpoint.undeploy( + deployed_model_id="model1", + traffic_split={"model1": 50, "model2": 50}, + sync=sync, + ) + + def test_predict(self, get_endpoint_mock, predict_client_predict_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = test_endpoint.predict( + instances=_TEST_INSTANCES, parameters={"param": 3.0} + ) + + true_prediction = models.Prediction( + predictions=_TEST_PREDICTION, deployed_model_id=_TEST_ID + ) + + assert true_prediction == test_prediction + predict_client_predict_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + ) + + def test_explain(self, get_endpoint_mock, predict_client_explain_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = test_endpoint.explain( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + ) + expected_explanations = _TEST_EXPLANATIONS + expected_explanations[0].attributions.extend(_TEST_ATTRIBUTIONS) + + expected_prediction = models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + explanations=expected_explanations, + ) + + assert expected_prediction == test_prediction + predict_client_explain_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + ) + + def test_list_models(self, get_endpoint_with_models_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + my_models = ept.list_models() + + assert my_models == _TEST_DEPLOYED_MODELS + + @pytest.mark.usefixtures("get_endpoint_with_models_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_all(self, sdk_private_undeploy_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + ept.undeploy_all(sync=sync) + + if not sync: + ept.wait() + + # undeploy_all() results in an undeploy() call for each deployed_model + sdk_private_undeploy_mock.assert_has_calls( + [ + mock.call(deployed_model_id=deployed_model.id, sync=sync) + for deployed_model in _TEST_DEPLOYED_MODELS + ], + any_order=True, + ) + + def test_list_endpoint_order_by_time(self, list_endpoints_mock): + """Test call to Endpoint.list() and ensure list is returned in descending order of create_time""" + + ep_list = aiplatform.Endpoint.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_CREATE_TIME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_endpoints_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + assert len(ep_list) == len(_TEST_ENDPOINT_LIST) + + for ep in ep_list: + assert type(ep) == aiplatform.Endpoint + + assert ep_list[0].create_time > ep_list[1].create_time > ep_list[2].create_time + + def test_list_endpoint_order_by_display_name(self, list_endpoints_mock): + """Test call to Endpoint.list() and ensure list is returned in order of display_name""" + + ep_list = aiplatform.Endpoint.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_DISPLAY_NAME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_endpoints_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + assert len(ep_list) == len(_TEST_ENDPOINT_LIST) + + for ep in ep_list: + assert type(ep) == aiplatform.Endpoint + + assert ( + ep_list[0].display_name < ep_list[1].display_name < ep_list[2].display_name + ) + + @pytest.mark.usefixtures("get_endpoint_with_models_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_endpoint_without_force( + self, sdk_undeploy_all_mock, delete_endpoint_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + ept.delete(sync=sync) + + if not sync: + ept.wait() + + # undeploy_all() should not be called unless force is set to True + sdk_undeploy_all_mock.assert_not_called() + + delete_endpoint_mock.assert_called_once_with(name=_TEST_ENDPOINT_NAME) + + @pytest.mark.usefixtures("get_endpoint_with_models_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_endpoint_with_force( + self, sdk_undeploy_all_mock, delete_endpoint_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + ept.delete(force=True, sync=sync) + + if not sync: + ept.wait() + + # undeploy_all() should be called if force is set to True + sdk_undeploy_all_mock.assert_called_once() + + delete_endpoint_mock.assert_called_once_with(name=_TEST_ENDPOINT_NAME) diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py new file mode 100644 index 0000000000..1d97ad2e9a --- /dev/null +++ b/tests/unit/aiplatform/test_initializer.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import os +import pytest +from unittest import mock + +import google.auth +from google.auth import credentials + +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) + +_TEST_PROJECT = "test-project" +_TEST_PROJECT_2 = "test-project-2" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" +_TEST_INVALID_LOCATION = "test-invalid-location" +_TEST_EXPERIMENT = "test-experiment" +_TEST_STAGING_BUCKET = "test-bucket" + + +class TestInit: + def setup_method(self): + importlib.reload(initializer) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_project_sets_project(self): + initializer.global_config.init(project=_TEST_PROJECT) + assert initializer.global_config.project == _TEST_PROJECT + + def test_not_init_project_gets_default_project(self, monkeypatch): + def mock_auth_default(): + return None, _TEST_PROJECT + + monkeypatch.setattr(google.auth, "default", mock_auth_default) + assert initializer.global_config.project == _TEST_PROJECT + + def test_init_location_sets_location(self): + initializer.global_config.init(location=_TEST_LOCATION) + assert initializer.global_config.location == _TEST_LOCATION + + def test_not_init_location_gets_default_location(self): + assert initializer.global_config.location == constants.DEFAULT_REGION + + def test_init_location_with_invalid_location_raises(self): + with pytest.raises(ValueError): + initializer.global_config.init(location=_TEST_INVALID_LOCATION) + + def test_init_experiment_sets_experiment(self): + initializer.global_config.init(experiment=_TEST_EXPERIMENT) + assert initializer.global_config.experiment == _TEST_EXPERIMENT + + def test_init_staging_bucket_sets_staging_bucket(self): + initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) + assert initializer.global_config.staging_bucket == _TEST_STAGING_BUCKET + + def test_init_credentials_sets_credentials(self): + creds = credentials.AnonymousCredentials() + initializer.global_config.init(credentials=creds) + assert initializer.global_config.credentials is creds + + def test_common_location_path_returns_parent(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + true_resource_parent = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + assert true_resource_parent == initializer.global_config.common_location_path() + + def test_common_location_path_overrides(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + true_resource_parent = ( + f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION_2}" + ) + assert true_resource_parent == initializer.global_config.common_location_path( + project=_TEST_PROJECT_2, location=_TEST_LOCATION_2 + ) + + def test_create_client_returns_client(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride + ) + assert client._client_class is model_service_client.ModelServiceClient + assert isinstance(client, utils.ModelClientWithOverride) + assert ( + client._transport._host == f"{_TEST_LOCATION}-{constants.API_BASE_PATH}:443" + ) + + def test_create_client_overrides(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = credentials.AnonymousCredentials() + client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION_2, + prediction_client=True, + ) + assert isinstance(client, utils.ModelClientWithOverride) + assert ( + client._transport._host + == f"{_TEST_LOCATION_2}-{constants.API_BASE_PATH}:443" + ) + assert client._transport._credentials == creds + + def test_create_client_user_agent(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride + ) + + for wrapped_method in client._transport._wrapped_methods.values(): + # wrapped_method._metadata looks like: + # [('x-goog-api-client', 'model-builder/0.3.1 gl-python/3.7.6 grpc/1.30.0 gax/1.22.2 gapic/0.3.1')] + user_agent = wrapped_method._metadata[0][1] + assert user_agent.startswith("model-builder/") + + @pytest.mark.parametrize( + "init_location, location_override, expected_endpoint", + [ + ("us-central1", None, "us-central1-aiplatform.googleapis.com"), + ("us-central1", "europe-west4", "europe-west4-aiplatform.googleapis.com",), + ("asia-east1", None, "asia-east1-aiplatform.googleapis.com"), + ], + ) + def test_get_client_options( + self, init_location: str, location_override: str, expected_endpoint: str, + ): + initializer.global_config.init(location=init_location) + + assert ( + initializer.global_config.get_client_options( + location_override=location_override + ).api_endpoint + == expected_endpoint + ) + + +class TestThreadPool: + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize( + "cpu_count, expected", [(4, 20), (32, 32), (None, 4), (2, 10)] + ) + def test_max_workers(self, cpu_count, expected): + with mock.patch.object(os, "cpu_count") as cpu_count_mock: + cpu_count_mock.return_value = cpu_count + importlib.reload(initializer) + assert initializer.global_pool._max_workers == expected diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py new file mode 100644 index 0000000000..acc7317ebb --- /dev/null +++ b/tests/unit/aiplatform/test_jobs.py @@ -0,0 +1,639 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from unittest import mock +from importlib import reload +from unittest.mock import patch + +from google.cloud import storage +from google.cloud import bigquery + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) + +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job_v1beta1, + explanation as gca_explanation_v1beta1, + io as gca_io_v1beta1, + machine_resources as gca_machine_resources_v1beta1, +) + +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client + +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, + io as gca_io, + job_state as gca_job_state, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ID = "1028944691210842416" +_TEST_ALT_ID = "8834795523125638878" +_TEST_DISPLAY_NAME = "my_job_1234" +_TEST_BQ_DATASET_ID = "bqDatasetId" +_TEST_BQ_JOB_ID = "123459876" +_TEST_BQ_MAX_RESULTS = 100 +_TEST_GCS_BUCKET_NAME = "my-bucket" + +_TEST_BQ_PATH = f"bq://projectId.{_TEST_BQ_DATASET_ID}" +_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}" +_TEST_GCS_JSONL_SOURCE_URI = f"{_TEST_GCS_BUCKET_PATH}/bp_input_config.jsonl" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ALT_ID}" +) +_TEST_BATCH_PREDICTION_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_ID}" +_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job" + +_TEST_BATCH_PREDICTION_GCS_SOURCE = "gs://example-bucket/folder/instance.jsonl" +_TEST_BATCH_PREDICTION_GCS_SOURCE_LIST = [ + "gs://example-bucket/folder/instance1.jsonl", + "gs://example-bucket/folder/instance2.jsonl", +] +_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX = "gs://example-bucket/folder/output" +_TEST_BATCH_PREDICTION_BQ_PREFIX = "ucaip-sample-tests" +_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL = ( + f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" +) + +_TEST_JOB_STATE_SUCCESS = gca_job_state.JobState(4) +_TEST_JOB_STATE_RUNNING = gca_job_state.JobState(3) +_TEST_JOB_STATE_PENDING = gca_job_state.JobState(2) + +_TEST_GCS_INPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_GCS_JSONL_SOURCE_URI]), +) +_TEST_GCS_OUTPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + predictions_format="jsonl", + gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_GCS_BUCKET_PATH), +) + +_TEST_BQ_INPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_PATH), +) +_TEST_BQ_OUTPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + predictions_format="bigquery", + bigquery_destination=gca_io.BigQueryDestination(output_uri=_TEST_BQ_PATH), +) + +_TEST_GCS_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_BUCKET_NAME +) +_TEST_BQ_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo( + bigquery_output_dataset=_TEST_BQ_PATH +) + +_TEST_EMPTY_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo() + +_TEST_GCS_BLOBS = [ + storage.Blob(name="some/path/prediction.jsonl", bucket=_TEST_GCS_BUCKET_NAME) +] + +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +_TEST_ACCELERATOR_COUNT = 2 +_TEST_STARTING_REPLICA_COUNT = 2 +_TEST_MAX_REPLICA_COUNT = 12 + +_TEST_LABEL = {"team": "experimentation", "trial_id": "x435"} + +_TEST_EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( + inputs={ + "features": { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + }, + outputs={"medv": {"output_tensor_name": "dense_2"}}, +) +_TEST_EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +_TEST_JOB_GET_METHOD_NAME = "get_fake_job" +_TEST_JOB_LIST_METHOD_NAME = "list_fake_job" +_TEST_JOB_CANCEL_METHOD_NAME = "cancel_fake_job" +_TEST_JOB_DELETE_METHOD_NAME = "delete_fake_job" +_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/fakeJobs/{_TEST_ID}" + +# TODO(b/171333554): Move reusable test fixtures to conftest.py file + + +@pytest.fixture +def fake_job_getter_mock(): + with patch.object( + job_service_client.JobServiceClient, _TEST_JOB_GET_METHOD_NAME, create=True + ) as fake_job_getter_mock: + fake_job_getter_mock.return_value = {} + yield fake_job_getter_mock + + +@pytest.fixture +def fake_job_cancel_mock(): + with patch.object( + job_service_client.JobServiceClient, _TEST_JOB_CANCEL_METHOD_NAME, create=True + ) as fake_job_cancel_mock: + yield fake_job_cancel_mock + + +class TestJob: + class FakeJob(jobs._Job): + _job_type = "fake-job" + _resource_noun = "fakeJobs" + _getter_method = _TEST_JOB_GET_METHOD_NAME + _list_method = _TEST_JOB_LIST_METHOD_NAME + _cancel_method = _TEST_JOB_CANCEL_METHOD_NAME + _delete_method = _TEST_JOB_DELETE_METHOD_NAME + resource_name = _TEST_JOB_RESOURCE_NAME + + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + # Unit Tests + def test_init_job_class(self): + """ + Raises TypeError since abstract property '_getter_method' is not set, + the _Job class should only be instantiated through a child class. + """ + with pytest.raises(TypeError): + jobs._Job(job_name=_TEST_BATCH_PREDICTION_JOB_NAME) + + @pytest.mark.usefixtures("fake_job_getter_mock") + def test_cancel_mock_job(self, fake_job_cancel_mock): + """Create a fake `_Job` child class, and ensure the high-level cancel method works""" + fake_job = self.FakeJob(job_name=_TEST_JOB_RESOURCE_NAME) + fake_job.cancel() + + fake_job_cancel_mock.assert_called_once_with(name=_TEST_JOB_RESOURCE_NAME) + + +@pytest.fixture +def get_batch_prediction_job_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.side_effect = [ + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_RUNNING, + ), + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ), + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ), + ] + yield get_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + create_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_with_explanations_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + create_batch_prediction_job_mock.return_value = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield create_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_gcs_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_GCS_OUTPUT_CONFIG, + output_info=_TEST_GCS_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_bq_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_empty_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_EMPTY_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_running_bq_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO, + state=_TEST_JOB_STATE_RUNNING, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def storage_list_blobs_mock(): + with patch.object(storage.Client, "list_blobs") as list_blobs_mock: + list_blobs_mock.return_value = _TEST_GCS_BLOBS + yield list_blobs_mock + + +@pytest.fixture +def bq_list_rows_mock(): + with patch.object(bigquery.Client, "list_rows") as list_rows_mock: + list_rows_mock.return_value = mock.Mock(bigquery.table.RowIterator) + yield list_rows_mock + + +class TestBatchPredictionJob: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_batch_prediction_job(self, get_batch_prediction_job_mock): + jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + get_batch_prediction_job_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + def test_batch_prediction_job_status(self, get_batch_prediction_job_mock): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + # get_batch_prediction() is called again here + bp_job_state = bp.state + + assert get_batch_prediction_job_mock.call_count == 2 + assert bp_job_state == _TEST_JOB_STATE_SUCCESS + + get_batch_prediction_job_mock.assert_called_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_gcs_output_mock") + def test_batch_prediction_iter_dirs_gcs(self, storage_list_blobs_mock): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + blobs = bp.iter_outputs() + + storage_list_blobs_mock.assert_called_once_with( + _TEST_GCS_OUTPUT_INFO.gcs_output_directory, prefix=None + ) + + assert blobs == _TEST_GCS_BLOBS + + @pytest.mark.usefixtures("get_batch_prediction_job_bq_output_mock") + def test_batch_prediction_iter_dirs_bq(self, bq_list_rows_mock): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + bp.iter_outputs() + + bq_list_rows_mock.assert_called_once_with( + table=f"{_TEST_BQ_DATASET_ID}.predictions", max_results=_TEST_BQ_MAX_RESULTS + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock") + def test_batch_prediction_iter_dirs_while_running(self): + """ + Raises RuntimeError since outputs cannot be read while BatchPredictionJob is still running + """ + with pytest.raises(RuntimeError): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + bp.iter_outputs() + + @pytest.mark.usefixtures("get_batch_prediction_job_empty_output_mock") + def test_batch_prediction_iter_dirs_invalid_output_info(self): + """ + Raises NotImplementedError since the BatchPredictionJob's output_info + contains no output GCS directory or BQ dataset. + """ + with pytest.raises(NotImplementedError): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + bp.iter_outputs() + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_and_dest( + self, create_batch_prediction_job_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call + batch_prediction_job = jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="jsonl", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_bq_dest( + self, create_batch_prediction_job_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + batch_prediction_job = jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL + ), + predictions_format="bigquery", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_with_all_args( + self, create_batch_prediction_job_with_explanations_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = auth_credentials.AnonymousCredentials() + + batch_prediction_job = jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + predictions_format="csv", + model_parameters={}, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + generate_explanation=True, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + labels=_TEST_LABEL, + credentials=creds, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io_v1beta1.GcsSource( + uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] + ), + ), + output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_v1beta1.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="csv", + ), + dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources( + machine_spec=gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ), + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + ), + generate_explanation=True, + explanation_spec=gca_explanation_v1beta1.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + labels=_TEST_LABEL, + ) + + create_batch_prediction_job_with_explanations_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_no_source(self, create_batch_prediction_job_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call without source + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call with two sources + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_no_destination(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call without destination + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + ) + + assert e.match(regexp=r"destination") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_wrong_instance_format(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + instances_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted instances format") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_wrong_prediction_format(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + predictions_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted prediction format") diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py new file mode 100644 index 0000000000..47b000d189 --- /dev/null +++ b/tests/unit/aiplatform/test_models.py @@ -0,0 +1,1130 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +from concurrent import futures +import pytest +from unittest import mock + +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job_v1beta1, + env_var as gca_env_var_v1beta1, + explanation as gca_explanation_v1beta1, + io as gca_io_v1beta1, + model as gca_model_v1beta1, + endpoint as gca_endpoint_v1beta1, + machine_resources as gca_machine_resources_v1beta1, + model_service as gca_model_service_v1beta1, + endpoint_service as gca_endpoint_service_v1beta1, + encryption_spec as gca_encryption_spec_v1beta1, +) + +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client, +) +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, + io as gca_io, + job_state as gca_job_state, + model as gca_model, + endpoint as gca_endpoint, + machine_resources as gca_machine_resources, + model_service as gca_model_service, + endpoint_service as gca_endpoint_service, + encryption_spec as gca_encryption_spec, +) + + +from test_endpoints import create_endpoint_mock # noqa: F401 + +_TEST_PROJECT = "test-project" +_TEST_PROJECT_2 = "test-project-2" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_MODEL_NAME = "test-model" +_TEST_ARTIFACT_URI = "gs://test/artifact/uri" +_TEST_SERVING_CONTAINER_IMAGE = "gcr.io/test-serving/container:image" +_TEST_SERVING_CONTAINER_PREDICTION_ROUTE = "predict" +_TEST_SERVING_CONTAINER_HEALTH_ROUTE = "metadata" +_TEST_DESCRIPTION = "test description" +_TEST_SERVING_CONTAINER_COMMAND = ["python3", "run_my_model.py"] +_TEST_SERVING_CONTAINER_ARGS = ["--test", "arg"] +_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES = { + "learning_rate": 0.01, + "loss_fn": "mse", +} +_TEST_SERVING_CONTAINER_PORTS = [8888, 10000] +_TEST_ID = "1028944691210842416" +_TEST_LABEL = {"team": "experimentation", "trial_id": "x435"} + +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +_TEST_ACCELERATOR_COUNT = 2 +_TEST_STARTING_REPLICA_COUNT = 2 +_TEST_MAX_REPLICA_COUNT = 12 + +_TEST_BATCH_PREDICTION_GCS_SOURCE = "gs://example-bucket/folder/instance.jsonl" +_TEST_BATCH_PREDICTION_GCS_SOURCE_LIST = [ + "gs://example-bucket/folder/instance1.jsonl", + "gs://example-bucket/folder/instance2.jsonl", +] +_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX = "gs://example-bucket/folder/output" +_TEST_BATCH_PREDICTION_BQ_PREFIX = "ucaip-sample-tests" +_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL = ( + f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" +) +_TEST_BATCH_PREDICTION_DISPLAY_NAME = "test-batch-prediction-job" +_TEST_BATCH_PREDICTION_JOB_NAME = job_service_client.JobServiceClient.batch_prediction_job_path( + project=_TEST_PROJECT, location=_TEST_LOCATION, batch_prediction_job=_TEST_ID +) + +_TEST_INSTANCE_SCHEMA_URI = "gs://test/schema/instance.yaml" +_TEST_PARAMETERS_SCHEMA_URI = "gs://test/schema/parameters.yaml" +_TEST_PREDICTION_SCHEMA_URI = "gs://test/schema/predictions.yaml" + +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) + +_TEST_EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( + inputs={ + "features": { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + }, + outputs={"medv": {"output_tensor_name": "dense_2"}}, +) +_TEST_EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) +_TEST_ENCRYPTION_SPEC_V1BETA1 = gca_encryption_spec_v1beta1.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_RESOURCE_NAME = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID +) +_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID +) +_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID +) + +_TEST_OUTPUT_DIR = "gs://my-output-bucket" + + +@pytest.fixture +def get_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ) + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_MODEL_NAME, name=test_endpoint_resource_name, + ) + yield get_endpoint_mock + + +@pytest.fixture +def get_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_explanations_mock(): + with mock.patch.object( + model_service_client_v1beta1.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model_v1beta1.Model( + display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_custom_location_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_custom_project_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT, + ) + yield get_model_mock + + +@pytest.fixture +def upload_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_explanations_mock(): + with mock.patch.object( + model_service_client_v1beta1.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service_v1beta1.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_custom_project_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_custom_location_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def delete_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "delete_model" + ) as delete_model_mock: + delete_model_lro_mock = mock.Mock(ga_operation.Operation) + delete_model_lro_mock.result.return_value = ( + gca_model_service.DeleteModelRequest() + ) + delete_model_mock.return_value = delete_model_lro_mock + yield delete_model_mock + + +@pytest.fixture +def deploy_model_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint.DeployedModel( + model=_TEST_MODEL_RESOURCE_NAME, display_name=_TEST_MODEL_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def deploy_model_with_explanations_mock(): + with mock.patch.object( + endpoint_service_client_v1beta1.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint_v1beta1.DeployedModel( + model=_TEST_MODEL_RESOURCE_NAME, display_name=_TEST_MODEL_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service_v1beta1.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def get_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + batch_prediction_mock = mock.Mock( + spec=gca_batch_prediction_job.BatchPredictionJob + ) + batch_prediction_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED + batch_prediction_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + get_batch_prediction_job_mock.return_value = batch_prediction_mock + yield get_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + batch_prediction_job_mock = mock.Mock( + spec=gca_batch_prediction_job.BatchPredictionJob + ) + batch_prediction_job_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + create_batch_prediction_job_mock.return_value = batch_prediction_job_mock + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_with_explanations_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + batch_prediction_job_mock = mock.Mock( + spec=gca_batch_prediction_job_v1beta1.BatchPredictionJob + ) + batch_prediction_job_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + create_batch_prediction_job_mock.return_value = batch_prediction_job_mock + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_client_mock(): + with mock.patch.object( + initializer.global_config, "create_client" + ) as create_client_mock: + api_client_mock = mock.Mock(spec=model_service_client.ModelServiceClient) + create_client_mock.return_value = api_client_mock + yield create_client_mock + + +class TestModel: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_constructor_creates_client(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + models.Model(_TEST_ID) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION, + prediction_client=False, + ) + + def test_constructor_create_client_with_custom_location(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + models.Model(_TEST_ID, location=_TEST_LOCATION_2) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION_2, + prediction_client=False, + ) + + def test_constructor_creates_client_with_custom_credentials( + self, create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = auth_credentials.AnonymousCredentials() + models.Model(_TEST_ID, credentials=creds) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=False, + ) + + def test_constructor_gets_model(self, get_model_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Model(_TEST_ID) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + + def test_constructor_gets_model_with_custom_project(self, get_model_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Model(_TEST_ID, project=_TEST_PROJECT_2) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + get_model_mock.assert_called_once_with(name=test_model_resource_name) + + def test_constructor_gets_model_with_custom_location(self, get_model_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Model(_TEST_ID, location=_TEST_LOCATION_2) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + get_model_mock.assert_called_once_with(name=test_model_resource_name) + + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model( + self, upload_model_mock, get_model_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, container_spec=container_spec, + ) + + upload_model_mock.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ) + + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + + def test_upload_raises_with_impartial_explanation_spec(self): + + with pytest.raises(ValueError) as e: + models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS + # Missing the required explanations_metadata field + ) + + assert e.match(regexp=r"`explanation_parameters` should be specified or None.") + + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model_with_all_args( + self, upload_model_with_explanations_mock, get_model_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, + description=_TEST_DESCRIPTION, + serving_container_command=_TEST_SERVING_CONTAINER_COMMAND, + serving_container_args=_TEST_SERVING_CONTAINER_ARGS, + serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + my_model.wait() + + env = [ + gca_env_var_v1beta1.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model_v1beta1.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_PORTS + ] + + container_spec = gca_model_v1beta1.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_SERVING_CONTAINER_COMMAND, + args=_TEST_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + managed_model = gca_model_v1beta1.Model( + display_name=_TEST_MODEL_NAME, + description=_TEST_DESCRIPTION, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + predict_schemata=gca_model_v1beta1.PredictSchemata( + instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, + ), + explanation_spec=gca_model_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) + + upload_model_with_explanations_mock.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + + @pytest.mark.usefixtures("get_model_with_custom_project_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model_with_custom_project( + self, + upload_model_with_custom_project_mock, + get_model_with_custom_project_mock, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + project=_TEST_PROJECT_2, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) + + upload_model_with_custom_project_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION}", + model=managed_model, + ) + + get_model_with_custom_project_mock.assert_called_once_with( + name=test_model_resource_name + ) + + @pytest.mark.usefixtures("get_model_with_custom_location_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model_with_custom_location( + self, + upload_model_with_custom_location_mock, + get_model_with_custom_location_mock, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + location=_TEST_LOCATION_2, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) + + upload_model_with_custom_location_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION_2}", + model=managed_model, + ) + + get_model_with_custom_location_mock.assert_called_once_with( + name=test_model_resource_name + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy(self, deploy_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_endpoint = models.Endpoint(_TEST_ID) + + assert test_model.deploy(test_endpoint, sync=sync,) == test_endpoint + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint(self, deploy_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_endpoint = test_model.deploy(sync=sync) + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_with_explanations( + self, deploy_model_with_explanations_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources( + machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 + ) + expected_deployed_model = gca_endpoint_v1beta1.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) + deploy_model_with_explanations_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + def test_deploy_raises_with_impartial_explanation_spec(self): + + test_model = models.Model(_TEST_ID) + + with pytest.raises(ValueError) as e: + test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + explanation_metadata=_TEST_EXPLANATION_METADATA, + # Missing required `explanation_parameters` argument + ) + + assert e.match(regexp=r"`explanation_parameters` should be specified or None.") + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_and_dest( + self, create_batch_prediction_job_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="jsonl", + ), + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_and_dest( + self, create_batch_prediction_job_mock, sync + ): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="jsonl", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_bq_dest( + self, create_batch_prediction_job_mock, sync + ): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL + ), + predictions_format="bigquery", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_with_all_args( + self, create_batch_prediction_job_with_explanations_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + creds = auth_credentials.AnonymousCredentials() + + # Make SDK batch_predict method call passing all arguments + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + predictions_format="csv", + model_parameters={}, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + generate_explanation=True, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + labels=_TEST_LABEL, + credentials=creds, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client_v1beta1.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io_v1beta1.GcsSource( + uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] + ), + ), + output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_v1beta1.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="csv", + ), + dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources( + machine_spec=gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ), + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + ), + generate_explanation=True, + explanation_spec=gca_explanation_v1beta1.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + labels=_TEST_LABEL, + encryption_spec=_TEST_ENCRYPTION_SPEC_V1BETA1, + ) + + create_batch_prediction_job_with_explanations_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_no_source(self, create_batch_prediction_job_mock): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call without source + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call with two sources + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_no_destination(self): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call without destination + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + ) + + assert e.match(regexp=r"destination") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_wrong_instance_format(self): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + instances_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted instances format") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_wrong_prediction_format(self): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + predictions_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted prediction format") + + @pytest.mark.usefixtures("get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_model(self, delete_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_model.delete(sync=sync) + + if not sync: + test_model.wait() + + delete_model_mock.assert_called_once_with(name=test_model.resource_name) + + @pytest.mark.usefixtures("get_model_mock") + def test_print_model(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + assert ( + repr(test_model) + == f"{object.__repr__(test_model)} \nresource name: {test_model.resource_name}" + ) + + @pytest.mark.usefixtures("get_model_mock") + def test_print_model_if_waiting(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + test_model._gca_resource = None + test_model._latest_future = futures.Future() + assert ( + repr(test_model) + == f"{object.__repr__(test_model)} is waiting for upstream dependencies to complete." + ) + + @pytest.mark.usefixtures("get_model_mock") + def test_print_model_if_exception(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + test_model._gca_resource = None + mock_exception = Exception("mock exception") + test_model._exception = mock_exception + assert ( + repr(test_model) + == f"{object.__repr__(test_model)} failed with {str(mock_exception)}" + ) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py new file mode 100644 index 0000000000..b5520a5f4c --- /dev/null +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -0,0 +1,3865 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from distutils import core +import functools +import importlib +import pathlib +import pytest +import subprocess +import shutil +import sys +import tarfile +import tempfile +from unittest import mock +from unittest.mock import patch + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + env_var as gca_env_var, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +from google.cloud import storage +from google.protobuf import json_format +from google.protobuf import struct_pb2 + + +_TEST_BUCKET_NAME = "test-bucket" +_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" +_TEST_GCS_PATH = f"{_TEST_BUCKET_NAME}/{_TEST_GCS_PATH_WITHOUT_BUCKET}" +_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" +_TEST_LOCAL_SCRIPT_FILE_NAME = "____test____script.py" +_TEST_LOCAL_SCRIPT_FILE_PATH = f"path/to/{_TEST_LOCAL_SCRIPT_FILE_NAME}" +_TEST_PYTHON_SOURCE = """ +print('hello world') +""" +_TEST_REQUIREMENTS = ["pandas", "numpy", "tensorflow"] + +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular +_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" +_TEST_TRAINING_CONTAINER_CMD = ["python3", "task.py"] +_TEST_SERVING_CONTAINER_IMAGE = "gcr.io/test-serving/container:image" +_TEST_SERVING_CONTAINER_PREDICTION_ROUTE = "predict" +_TEST_SERVING_CONTAINER_HEALTH_ROUTE = "metadata" + +_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image +_TEST_ANNOTATION_SCHEMA_URI = schema.dataset.annotation.image.classification + +_TEST_BASE_OUTPUT_DIR = "gs://test-base-output-dir" +_TEST_BIGQUERY_DESTINATION = "bq://test-project" +_TEST_RUN_ARGS = ["-v", 0.1, "--test=arg"] +_TEST_REPLICA_COUNT = 1 +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_K80" +_TEST_INVALID_ACCELERATOR_TYPE = "NVIDIA_DOES_NOT_EXIST" +_TEST_ACCELERATOR_COUNT = 1 +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_DEFAULT_TRAINING_FRACTION_SPLIT = 0.8 +_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT = 0.1 +_TEST_DEFAULT_TEST_FRACTION_SPLIT = 0.1 +_TEST_TRAINING_FRACTION_SPLIT = 0.6 +_TEST_VALIDATION_FRACTION_SPLIT = 0.2 +_TEST_TEST_FRACTION_SPLIT = 0.2 +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ID = "12345" +_TEST_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipelines/{_TEST_ID}" +) +_TEST_ALT_PROJECT = "test-project-alt" +_TEST_ALT_LOCATION = "europe-west4" + +_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml" +_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml" +_TEST_MODEL_PREDICTION_SCHEMA_URI = "prediction_schema_uri.yaml" +_TEST_MODEL_SERVING_CONTAINER_COMMAND = ["test_command"] +_TEST_MODEL_SERVING_CONTAINER_ARGS = ["test_args"] +_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES = { + "learning_rate": 0.01, + "loss_fn": "mse", +} +_TEST_MODEL_SERVING_CONTAINER_PORTS = [8888, 10000] +_TEST_MODEL_DESCRIPTION = "test description" + +_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" +_TEST_PYTHON_MODULE_NAME = "aiplatform.task" + +_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" + +_TEST_PIPELINE_RESOURCE_NAME = ( + "projects/my-project/locations/us-central1/trainingPipeline/12345" +) +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +def local_copy_method(path): + shutil.copy(path, ".") + return pathlib.Path(path).name + + +@pytest.fixture +def get_training_job_custom_mock(): + with patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as get_training_job_custom_mock: + get_training_job_custom_mock.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + training_task_definition=schema.training_job.definition.custom_task, + ) + + yield get_training_job_custom_mock + + +@pytest.fixture +def get_training_job_custom_mock_no_model_to_upload(): + with patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as get_training_job_custom_mock: + get_training_job_custom_mock.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=None, + training_task_definition=schema.training_job.definition.custom_task, + ) + + yield get_training_job_custom_mock + + +@pytest.fixture +def get_training_job_tabular_mock(): + with patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as get_training_job_tabular_mock: + get_training_job_tabular_mock.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + training_task_definition=schema.training_job.definition.automl_tabular, + ) + + yield get_training_job_tabular_mock + + +@pytest.fixture +def mock_client_bucket(): + with patch.object(storage.Client, "bucket") as mock_client_bucket: + + def blob_side_effect(name, mock_blob, bucket): + mock_blob.name = name + mock_blob.bucket = bucket + return mock_blob + + MockBucket = mock.Mock(autospec=storage.Bucket) + MockBucket.name = _TEST_BUCKET_NAME + MockBlob = mock.Mock(autospec=storage.Blob) + MockBucket.blob.side_effect = functools.partial( + blob_side_effect, mock_blob=MockBlob, bucket=MockBucket + ) + mock_client_bucket.return_value = MockBucket + + yield mock_client_bucket, MockBlob + + +class TestTrainingScriptPythonPackagerHelpers: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def test_timestamp_copy_to_gcs_calls_gcs_client_with_bucket( + self, mock_client_bucket + ): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + ) + + local_script_file_name = pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + blob_arg = mock_client_bucket.return_value.blob.call_args[0][0] + assert blob_arg.startswith("aiplatform-") + assert blob_arg.endswith(_TEST_LOCAL_SCRIPT_FILE_NAME) + + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + assert gcs_path.endswith(local_script_file_name) + assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}/aiplatform-") + + def test_timestamp_copy_to_gcs_calls_gcs_client_with_gcs_path( + self, mock_client_bucket + ): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_GCS_PATH_WITH_TRAILING_SLASH, + project=_TEST_PROJECT, + ) + + local_script_file_name = pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + blob_arg = mock_client_bucket.return_value.blob.call_args[0][0] + assert blob_arg.startswith(f"{_TEST_GCS_PATH_WITHOUT_BUCKET}/aiplatform-") + assert blob_arg.endswith(f"{_TEST_LOCAL_SCRIPT_FILE_NAME}") + + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + + assert gcs_path.startswith(f"gs://{_TEST_GCS_PATH}/aiplatform-") + assert gcs_path.endswith(local_script_file_name) + + def test_timestamp_copy_to_gcs_calls_gcs_client_with_trailing_slash( + self, mock_client_bucket + ): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_GCS_PATH, + project=_TEST_PROJECT, + ) + + local_script_file_name = pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + blob_arg = mock_client_bucket.return_value.blob.call_args[0][0] + assert blob_arg.startswith(f"{_TEST_GCS_PATH_WITHOUT_BUCKET}/aiplatform-") + assert blob_arg.endswith(_TEST_LOCAL_SCRIPT_FILE_NAME) + + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + + assert gcs_path.startswith(f"gs://{_TEST_GCS_PATH}/aiplatform-") + assert gcs_path.endswith(local_script_file_name) + + def test_timestamp_copy_to_gcs_calls_gcs_client(self, mock_client_bucket): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + ) + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + assert gcs_path.endswith(pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name) + assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}") + + def test_get_python_executable_raises_if_None(self): + with patch.object(sys, "executable", new=None): + with pytest.raises(EnvironmentError): + training_jobs._get_python_executable() + + def test_get_python_executable_returns_python_executable(self): + assert "python" in training_jobs._get_python_executable().lower() + + +class TestTrainingScriptPythonPackager: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + with open(_TEST_LOCAL_SCRIPT_FILE_NAME, "w") as fp: + fp.write(_TEST_PYTHON_SOURCE) + + def teardown_method(self): + pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_NAME).unlink() + python_package_file = f"{training_jobs._TrainingScriptPythonPackager._ROOT_MODULE}-{training_jobs._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" + if pathlib.Path(python_package_file).is_file(): + pathlib.Path(python_package_file).unlink() + subprocess.check_output( + [ + "pip3", + "uninstall", + "-y", + training_jobs._TrainingScriptPythonPackager._ROOT_MODULE, + ] + ) + + def test_packager_creates_and_copies_python_package(self): + tsp = training_jobs._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_NAME) + tsp.package_and_copy(copy_method=local_copy_method) + assert pathlib.Path( + f"{tsp._ROOT_MODULE}-{tsp._SETUP_PY_VERSION}.tar.gz" + ).is_file() + + def test_created_package_module_is_installable_and_can_be_run(self): + tsp = training_jobs._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_NAME) + source_dist_path = tsp.package_and_copy(copy_method=local_copy_method) + subprocess.check_output(["pip3", "install", source_dist_path]) + module_output = subprocess.check_output( + [training_jobs._get_python_executable(), "-m", tsp.module_name] + ) + assert "hello world" in module_output.decode() + + def test_requirements_are_in_package(self): + tsp = training_jobs._TrainingScriptPythonPackager( + _TEST_LOCAL_SCRIPT_FILE_NAME, requirements=_TEST_REQUIREMENTS + ) + source_dist_path = tsp.package_and_copy(copy_method=local_copy_method) + with tarfile.open(source_dist_path) as tf: + with tempfile.TemporaryDirectory() as tmpdirname: + setup_py_path = f"{training_jobs._TrainingScriptPythonPackager._ROOT_MODULE}-{training_jobs._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" + tf.extract(setup_py_path, path=tmpdirname) + setup_py = core.run_setup( + pathlib.Path(tmpdirname, setup_py_path), stop_after="init" + ) + assert _TEST_REQUIREMENTS == setup_py.install_requires + + def test_packaging_fails_whith_RuntimeError(self): + with patch("subprocess.Popen") as mock_popen: + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 1 + mock_popen.return_value = mock_subprocess + tsp = training_jobs._TrainingScriptPythonPackager( + _TEST_LOCAL_SCRIPT_FILE_NAME + ) + with pytest.raises(RuntimeError): + tsp.package_and_copy(copy_method=local_copy_method) + + def test_package_and_copy_to_gcs_copies_to_gcs(self, mock_client_bucket): + mock_client_bucket, mock_blob = mock_client_bucket + + tsp = training_jobs._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_NAME) + + gcs_path = tsp.package_and_copy_to_gcs( + gcs_staging_dir=_TEST_BUCKET_NAME, project=_TEST_PROJECT + ) + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + mock_blob.upload_from_filename.call_args[0][0].endswith( + "/trainer/dist/aiplatform_custom_trainer_script-0.1.tar.gz" + ) + + assert gcs_path.endswith("-aiplatform_custom_trainer_script-0.1.tar.gz") + assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}") + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_cancel(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "cancel_training_pipeline" + ) as mock_cancel_training_pipeline: + yield mock_cancel_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_with_no_model_to_upload(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get_with_no_model_to_upload(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model(name=_TEST_MODEL_NAME) + yield mock_get_model + + +@pytest.fixture +def mock_python_package_to_gcs(): + with mock.patch.object( + training_jobs._TrainingScriptPythonPackager, "package_and_copy_to_gcs" + ) as mock_package_to_copy_gcs: + mock_package_to_copy_gcs.return_value = _TEST_OUTPUT_PYTHON_PACKAGE_PATH + yield mock_package_to_copy_gcs + + +@pytest.fixture +def mock_tabular_dataset(): + ds = mock.MagicMock(datasets.TabularDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_nontabular_dataset(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +class TestCustomTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + with open(_TEST_LOCAL_SCRIPT_FILE_NAME, "w") as fp: + fp.write(_TEST_PYTHON_SOURCE) + + def teardown_method(self): + pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_NAME).unlink() + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_bigquery_destination( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + bigquery_destination=_TEST_BIGQUERY_DESTINATION, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BIGQUERY_DESTINATION + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_invalid_accelerator_type_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_incomplete_model_info_raises_with_model_to_upload( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_no_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, container_spec=true_container_spec + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_returns_none_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + model = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + assert model is None + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_get_model_raises_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_python_package_to_gcs, + mock_tabular_dataset, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called( + self, mock_pipeline_service_create, mock_python_package_to_gcs + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state + + def test_run_raises_if_no_staging_bucket(self): + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(RuntimeError): + training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_distributed_training( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = [ + { + "replicaCount": 1, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + { + "replicaCount": 9, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + ] + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": true_worker_pool_spec, + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("get_training_job_custom_mock") + def test_get_training_job(self, get_training_job_custom_mock): + aiplatform.init(project=_TEST_PROJECT) + job = training_jobs.CustomTrainingJob.get(resource_name=_TEST_NAME) + + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + assert isinstance(job, training_jobs.CustomTrainingJob) + + @pytest.mark.usefixtures("get_training_job_custom_mock") + def test_get_training_job_wrong_job_type(self, get_training_job_custom_mock): + aiplatform.init(project=_TEST_PROJECT) + + # The returned job is for a custom training task, + # but the calling type if of AutoMLImageTrainingJob. + # Hence, it should throw an error. + with pytest.raises(ValueError): + training_jobs.AutoMLImageTrainingJob.get(resource_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_training_job_custom_mock_no_model_to_upload") + def test_get_training_job_no_model_to_upload( + self, get_training_job_custom_mock_no_model_to_upload + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.CustomTrainingJob.get(resource_name=_TEST_NAME) + + with pytest.raises(RuntimeError): + job.get_model(sync=False) + + @pytest.mark.usefixtures("get_training_job_tabular_mock") + def test_get_training_job_tabular(self, get_training_job_tabular_mock): + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError): + training_jobs.CustomTrainingJob.get(resource_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_training_job_custom_mock") + def test_get_training_job_with_id_only(self, get_training_job_custom_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get(resource_name=_TEST_ID) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_id_only_with_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_alt_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_project_and_alt_location(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(RuntimeError): + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, + project=_TEST_PROJECT, + location=_TEST_ALT_LOCATION, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_nontabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_nontabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_nontabular_dataset, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_nontabular_dataset.name, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + def test_run_call_pipeline_service_create_with_nontabular_dataset_raises_if_annotation_schema_uri( + self, mock_nontabular_dataset, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + with pytest.raises(Exception): + job.run( + dataset=mock_nontabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_cancel_training_job(self, mock_pipeline_service_cancel): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run() + job.cancel() + + mock_pipeline_service_cancel.assert_called_once_with( + name=_TEST_PIPELINE_RESOURCE_NAME + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_cancel_training_job_without_running(self, mock_pipeline_service_cancel): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError) as e: + job.cancel() + + assert e.match(regexp=r"TrainingJob has not been launched") + + +class TestCustomContainerTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_bigquery_destination( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + bigquery_destination=_TEST_BIGQUERY_DESTINATION, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BIGQUERY_DESTINATION + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_invalid_accelerator_type_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_incomplete_model_info_raises_with_model_to_upload( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_no_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, container_spec=true_container_spec + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_returns_none_if_no_model_to_upload( + self, + mock_pipeline_service_create_with_no_model_to_upload, + mock_pipeline_service_get_with_no_model_to_upload, + mock_tabular_dataset, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + model = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + assert model is None + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_get_model_raises_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_tabular_dataset, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state + + def test_run_raises_if_no_staging_bucket(self): + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(RuntimeError): + training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_distributed_training( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = [ + { + "replicaCount": 1, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + }, + { + "replicaCount": 9, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + }, + ] + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": true_worker_pool_spec, + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_nontabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_nontabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_nontabular_dataset, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_nontabular_dataset.name, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + def test_run_call_pipeline_service_create_with_nontabular_dataset_raises_if_annotation_schema_uri( + self, mock_nontabular_dataset, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + with pytest.raises(Exception): + job.run( + dataset=mock_nontabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + ) + + +class Test_MachineSpec: + def test_machine_spec_return_spec_dict(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ) + + true_spec_dict = { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": _TEST_REPLICA_COUNT, + } + + assert test_spec.spec_dict == true_spec_dict + + def test_machine_spec_return_spec_dict_with_no_accelerator(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=0, + accelerator_type="ACCELERATOR_TYPE_UNSPECIFIED", + ) + + true_spec_dict = { + "machineSpec": {"machineType": _TEST_MACHINE_TYPE}, + "replicaCount": _TEST_REPLICA_COUNT, + } + + assert test_spec.spec_dict == true_spec_dict + + def test_machine_spec_spec_dict_raises_invalid_accelerator(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + ) + + with pytest.raises(ValueError): + test_spec.spec_dict + + def test_machine_spec_spec_dict_is_empty(self): + test_spec = training_jobs._MachineSpec( + replica_count=0, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + ) + + assert test_spec.is_empty + + def test_machine_spec_spec_dict_is_not_empty(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + ) + + assert not test_spec.is_empty + + +class Test_DistributedTrainingSpec: + def test_machine_spec_returns_pool_spec(self): + + spec = training_jobs._DistributedTrainingSpec( + chief_spec=training_jobs._MachineSpec( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + worker_spec=training_jobs._MachineSpec( + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + parameter_server_spec=training_jobs._MachineSpec( + replica_count=3, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + evaluator_spec=training_jobs._MachineSpec( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 10, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 3, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + ] + + assert spec.pool_specs == true_pool_spec + + def test_chief_worker_pool_returns_spec(self): + + chief_worker_spec = training_jobs._DistributedTrainingSpec.chief_worker_pool( + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 9, + }, + ] + + assert chief_worker_spec.pool_specs == true_pool_spec + + def test_chief_worker_pool_returns_just_chief(self): + + chief_worker_spec = training_jobs._DistributedTrainingSpec.chief_worker_pool( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + } + ] + + assert chief_worker_spec.pool_specs == true_pool_spec + + def test_machine_spec_raise_with_more_than_one_chief_replica(self): + + spec = training_jobs._DistributedTrainingSpec( + chief_spec=training_jobs._MachineSpec( + replica_count=2, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + ) + + with pytest.raises(ValueError): + spec.pool_specs + + def test_machine_spec_handles_missing_pools(self): + + spec = training_jobs._DistributedTrainingSpec( + chief_spec=training_jobs._MachineSpec( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + worker_spec=training_jobs._MachineSpec(replica_count=0), + parameter_server_spec=training_jobs._MachineSpec( + replica_count=3, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + evaluator_spec=training_jobs._MachineSpec(replica_count=0), + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + {"machineSpec": {"machineType": "n1-standard-4"}, "replicaCount": 0}, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 3, + }, + ] + + assert spec.pool_specs == true_pool_spec + + +class TestCustomPythonPackageTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_bigquery_destination( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + bigquery_destination=_TEST_BIGQUERY_DESTINATION, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BIGQUERY_DESTINATION + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_invalid_accelerator_type_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_incomplete_model_info_raises_with_model_to_upload( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_no_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, container_spec=true_container_spec + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_returns_none_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + model = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + assert model is None + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_get_model_raises_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_tabular_dataset, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state + + def test_run_raises_if_no_staging_bucket(self): + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(RuntimeError): + training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_distributed_training( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = [ + { + "replicaCount": 1, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + { + "replicaCount": 9, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + ] + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": true_worker_pool_spec, + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_nontabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_nontabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_nontabular_dataset, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_nontabular_dataset.name, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + def test_run_call_pipeline_service_create_with_nontabular_dataset_raises_if_annotation_schema_uri( + self, mock_nontabular_dataset, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + with pytest.raises(Exception): + job.run( + dataset=mock_nontabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + ) diff --git a/tests/unit/aiplatform/test_training_utils.py b/tests/unit/aiplatform/test_training_utils.py new file mode 100644 index 0000000000..1d4b839151 --- /dev/null +++ b/tests/unit/aiplatform/test_training_utils.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os +import pytest + +from google.cloud.aiplatform import training_utils +from unittest import mock + +_TEST_TRAINING_DATA_URI = "gs://training-data-uri" +_TEST_VALIDATION_DATA_URI = "gs://test-validation-data-uri" +_TEST_TEST_DATA_URI = "gs://test-data-uri" +_TEST_MODEL_DIR = "gs://test-model-dir" +_TEST_CHECKPOINT_DIR = "gs://test-checkpoint-dir" +_TEST_TENSORBOARD_LOG_DIR = "gs://test-tensorboard-log-dir" +_TEST_CLUSTER_SPEC = """{ + "cluster": { + "worker_pools":[ + { + "index":0, + "replicas":[ + "training-workerpool0-ab-0:2222" + ] + }, + { + "index":1, + "replicas":[ + "training-workerpool1-ab-0:2222", + "training-workerpool1-ab-1:2222" + ] + } + ] + }, + "environment": "cloud", + "task": { + "worker_pool_index":0, + "replica_index":0, + "trial":"TRIAL_ID" + } +}""" + + +class TestTrainingUtils: + @pytest.fixture + def mock_environment(self): + env_vars = { + "AIP_TRAINING_DATA_URI": _TEST_TRAINING_DATA_URI, + "AIP_VALIDATION_DATA_URI": _TEST_VALIDATION_DATA_URI, + "AIP_TEST_DATA_URI": _TEST_TEST_DATA_URI, + "AIP_MODEL_DIR": _TEST_MODEL_DIR, + "AIP_CHECKPOINT_DIR": _TEST_CHECKPOINT_DIR, + "AIP_TENSORBOARD_LOG_DIR": _TEST_TENSORBOARD_LOG_DIR, + "CLUSTER_SPEC": _TEST_CLUSTER_SPEC, + "TF_CONFIG": _TEST_CLUSTER_SPEC, + } + with mock.patch.dict(os.environ, env_vars): + yield + + @pytest.mark.usefixtures("mock_environment") + def test_training_data_uri(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.training_data_uri == _TEST_TRAINING_DATA_URI + + def test_training_data_uri_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.training_data_uri is None + + @pytest.mark.usefixtures("mock_environment") + def test_validation_data_uri(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.validation_data_uri == _TEST_VALIDATION_DATA_URI + + def test_validation_data_uri_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.validation_data_uri is None + + @pytest.mark.usefixtures("mock_environment") + def test_test_data_uri(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.test_data_uri == _TEST_TEST_DATA_URI + + def test_test_data_uri_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.test_data_uri is None + + @pytest.mark.usefixtures("mock_environment") + def test_model_dir(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.model_dir == _TEST_MODEL_DIR + + def test_model_dir_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.model_dir is None + + @pytest.mark.usefixtures("mock_environment") + def test_checkpoint_dir(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.checkpoint_dir == _TEST_CHECKPOINT_DIR + + def test_checkpoint_dir_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.checkpoint_dir is None + + @pytest.mark.usefixtures("mock_environment") + def test_tensorboard_log_dir(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tensorboard_log_dir == _TEST_TENSORBOARD_LOG_DIR + + def test_tensorboard_log_dir_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tensorboard_log_dir is None + + @pytest.mark.usefixtures("mock_environment") + def test_cluster_spec(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.cluster_spec == json.loads(_TEST_CLUSTER_SPEC) + + def test_cluster_spec_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.cluster_spec is None + + @pytest.mark.usefixtures("mock_environment") + def test_tf_config(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tf_config == json.loads(_TEST_CLUSTER_SPEC) + + def test_tf_config_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tf_config is None diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py new file mode 100644 index 0000000000..3032475069 --- /dev/null +++ b/tests/unit/aiplatform/test_utils.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from uuid import uuid4 +from random import choice +from random import randint +from string import ascii_letters + +from google.api_core import client_options +from google.api_core import gapic_v1 +from google.cloud import aiplatform +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client_v1, +) + +model_service_client_default = model_service_client_v1 + + +@pytest.mark.parametrize( + "resource_name, expected", + [ + ("projects/123456/locations/us-central1/datasets/987654", True), + ("projects/857392/locations/us-central1/trainingPipelines/347292", True), + ("projects/acme-co-proj-1/locations/us-central1/datasets/123456", True), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", False), + ("project/123456/locations/us-central1/datasets/987654", False), + ("project//locations//datasets/987654", False), + ("locations/europe-west4/datasets/987654", False), + ("987654", False), + ], +) +def test_extract_fields_from_resource_name(resource_name: str, expected: bool): + # Given a resource name and expected validity, test extract_fields_from_resource_name() + assert expected == bool(utils.extract_fields_from_resource_name(resource_name)) + + +@pytest.fixture +def generated_resource_fields(): + generated_fields = utils.Fields( + project=str(uuid4()), + location=str(uuid4()), + resource="".join(choice(ascii_letters) for i in range(10)), # 10 random letters + id=str(randint(0, 100000)), + ) + + yield generated_fields + + +@pytest.fixture +def generated_resource_name(generated_resource_fields: utils.Fields): + name = ( + f"projects/{generated_resource_fields.project}/" + f"locations/{generated_resource_fields.location}" + f"/{generated_resource_fields.resource}/{generated_resource_fields.id}" + ) + + yield name + + +def test_extract_fields_from_resource_name_with_extracted_fields( + generated_resource_name: str, generated_resource_fields: utils.Fields +): + """Verify fields extracted from resource name match the original fields""" + + assert ( + utils.extract_fields_from_resource_name(resource_name=generated_resource_name) + == generated_resource_fields + ) + + +@pytest.mark.parametrize( + "resource_name, resource_noun, expected", + [ + # Expects pattern "projects/.../locations/.../datasets/..." + ("projects/123456/locations/us-central1/datasets/987654", "datasets", True), + # Expects pattern "projects/.../locations/.../batchPredictionJobs/..." + ( + "projects/857392/locations/us-central1/trainingPipelines/347292", + "batchPredictionJobs", + False, + ), + ], +) +def test_extract_fields_from_resource_name_with_resource_noun( + resource_name: str, resource_noun: str, expected: bool +): + assert ( + bool( + utils.extract_fields_from_resource_name( + resource_name=resource_name, resource_noun=resource_noun + ) + ) + == expected + ) + + +def test_invalid_region_raises_with_invalid_region(): + with pytest.raises(ValueError): + aiplatform.utils.validate_region(region="us-west4") + + +def test_invalid_region_does_not_raise_with_valid_region(): + aiplatform.utils.validate_region(region="us-central1") + + +@pytest.mark.parametrize( + "resource_noun, project, location, full_name", + [ + ( + "datasets", + "123456", + "us-central1", + "projects/123456/locations/us-central1/datasets/987654", + ), + ( + "trainingPipelines", + "857392", + "us-west20", + "projects/857392/locations/us-central1/trainingPipelines/347292", + ), + ], +) +def test_full_resource_name_with_full_name( + resource_noun: str, project: str, location: str, full_name: str, +): + # should ignore issues with other arguments as resource_name is full_name + assert ( + aiplatform.utils.full_resource_name( + resource_name=full_name, + resource_noun=resource_noun, + project=project, + location=location, + ) + == full_name + ) + + +@pytest.mark.parametrize( + "partial_name, resource_noun, project, location, full_name", + [ + ( + "987654", + "datasets", + "123456", + "us-central1", + "projects/123456/locations/us-central1/datasets/987654", + ), + ( + "347292", + "trainingPipelines", + "857392", + "us-central1", + "projects/857392/locations/us-central1/trainingPipelines/347292", + ), + ], +) +def test_full_resource_name_with_partial_name( + partial_name: str, resource_noun: str, project: str, location: str, full_name: str, +): + assert ( + aiplatform.utils.full_resource_name( + resource_name=partial_name, + resource_noun=resource_noun, + project=project, + location=location, + ) + == full_name + ) + + +@pytest.mark.parametrize( + "partial_name, resource_noun, project, location", + [("347292", "trainingPipelines", "857392", "us-west2020")], +) +def test_full_resource_name_raises_value_error( + partial_name: str, resource_noun: str, project: str, location: str, +): + with pytest.raises(ValueError): + aiplatform.utils.full_resource_name( + resource_name=partial_name, + resource_noun=resource_noun, + project=project, + location=location, + ) + + +def test_validate_display_name_raises_length(): + with pytest.raises(ValueError): + aiplatform.utils.validate_display_name( + "slanflksdnlikh;likhq290u90rflkasndfkljashndfkl;jhowq2342;iehoiwerhowqihjer34564356o;iqwjr;oijsdalfjasl;kfjas;ldifhja;slkdfsdlkfhj" + ) + + +def test_validate_display_name(): + aiplatform.utils.validate_display_name("my_model_abc") + + +@pytest.mark.parametrize( + "accelerator_type, expected", + [ + ("NVIDIA_TESLA_K80", True), + ("ACCELERATOR_TYPE_UNSPECIFIED", True), + ("NONEXISTENT_GPU", False), + ("NVIDIA_GALAXY_R7", False), + ("", False), + (None, False), + ], +) +def test_validate_accelerator_type(accelerator_type: str, expected: bool): + # Invalid type raises specific ValueError + if not expected: + with pytest.raises(ValueError) as e: + utils.validate_accelerator_type(accelerator_type) + assert e.match(regexp=r"Given accelerator_type") + # Valid type returns True + else: + assert utils.validate_accelerator_type(accelerator_type) + + +@pytest.mark.parametrize( + "gcs_path, expected", + [ + ("gs://example-bucket/path/to/folder", ("example-bucket", "path/to/folder")), + ("example-bucket/path/to/folder/", ("example-bucket", "path/to/folder")), + ("gs://example-bucket", ("example-bucket", None)), + ("gs://example-bucket/", ("example-bucket", None)), + ("gs://example-bucket/path", ("example-bucket", "path")), + ], +) +def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple): + # Given a GCS path, ensure correct bucket and prefix are extracted + assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path) + + +def test_wrapped_client(): + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + wrapped_client = utils.ClientWithOverride.WrappedClient( + client_class=model_service_client_default.ModelServiceClient, + client_options=test_client_options, + client_info=test_client_info, + ) + + assert isinstance( + wrapped_client.get_model.__self__, + model_service_client_default.ModelServiceClient, + ) + + +def test_client_w_override_default_version(): + + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + client_w_override = utils.ModelClientWithOverride( + client_options=test_client_options, client_info=test_client_info, + ) + assert isinstance( + client_w_override._clients[ + client_w_override._default_version + ].get_model.__self__, + model_service_client_default.ModelServiceClient, + ) + + +def test_client_w_override_select_version(): + + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + client_w_override = utils.ModelClientWithOverride( + client_options=test_client_options, client_info=test_client_info, + ) + + assert isinstance( + client_w_override.select_version(compat.V1BETA1).get_model.__self__, + model_service_client_v1beta1.ModelServiceClient, + ) + assert isinstance( + client_w_override.select_version(compat.V1).get_model.__self__, + model_service_client_v1.ModelServiceClient, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index ecf6f7286a..3fe62e7836 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -1420,19 +1420,17 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - location = "mussel" - dataset = "winkle" + dataset = "mussel" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1442,9 +1440,9 @@ def test_parse_dataset_path(): assert expected == actual def test_dataset_path(): - project = "squid" - location = "clam" - dataset = "whelk" + project = "scallop" + location = "abalone" + dataset = "squid" expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, location, dataset) @@ -1453,9 +1451,9 @@ def test_dataset_path(): def test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", + "project": "clam", + "location": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1465,17 +1463,19 @@ def test_parse_dataset_path(): assert expected == actual def test_dataset_path(): - project = "cuttlefish" - dataset = "mussel" + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", + "project": "mussel", + "location": "winkle", "dataset": "nautilus", } diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 134e0632c7..51d76cb3c4 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -1420,17 +1420,19 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" + location = "mussel" + dataset = "winkle" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1440,19 +1442,17 @@ def test_parse_dataset_path(): assert expected == actual def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" + project = "squid" + dataset = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", + "project": "whelk", "dataset": "octopus", } From 1962ecc322de4c8a71870079e393100138b827c5 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Tue, 20 Apr 2021 08:48:55 -0400 Subject: [PATCH 12/36] chore: restore metadata experiment testing --- tests/unit/aiplatform/test_initializer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 1d97ad2e9a..84498d0a37 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -19,11 +19,13 @@ import os import pytest from unittest import mock +from unittest.mock import patch import google.auth from google.auth import credentials from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata.metadata import metadata_service from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils @@ -69,9 +71,10 @@ def test_init_location_with_invalid_location_raises(self): with pytest.raises(ValueError): initializer.global_config.init(location=_TEST_INVALID_LOCATION) - def test_init_experiment_sets_experiment(self): + @patch.object(metadata_service, "set_experiment") + def test_init_experiment_sets_experiment(self, set_experiment_mock): initializer.global_config.init(experiment=_TEST_EXPERIMENT) - assert initializer.global_config.experiment == _TEST_EXPERIMENT + set_experiment_mock.assert_called_once_with(_TEST_EXPERIMENT) def test_init_staging_bucket_sets_staging_bucket(self): initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) From 10fb38a5ed4d24e2e108a501b549184e32f33862 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Tue, 20 Apr 2021 09:23:55 -0400 Subject: [PATCH 13/36] chore:lint --- .../services/dataset_service/client.py | 543 +- .../dataset_service/transports/base.py | 223 +- .../services/endpoint_service/client.py | 400 +- .../endpoint_service/transports/base.py | 166 +- .../services/job_service/client.py | 954 ++- .../services/job_service/transports/base.py | 351 +- .../services/migration_service/client.py | 282 +- .../services/model_service/client.py | 555 +- .../services/model_service/transports/base.py | 214 +- .../services/pipeline_service/client.py | 329 +- .../prediction_service/transports/base.py | 70 +- .../specialist_pool_service/client.py | 309 +- .../transports/base.py | 121 +- google/cloud/aiplatform_v1beta1/__init__.py | 730 +-- .../services/dataset_service/client.py | 543 +- .../services/endpoint_service/client.py | 400 +- .../__init__.py | 4 +- .../async_client.py | 176 +- .../client.py | 236 +- .../transports/__init__.py | 16 +- .../transports/base.py | 85 +- .../transports/grpc.py | 99 +- .../transports/grpc_asyncio.py | 100 +- .../services/featurestore_service/__init__.py | 4 +- .../featurestore_service/async_client.py | 769 ++- .../services/featurestore_service/client.py | 859 ++- .../services/featurestore_service/pagers.py | 157 +- .../transports/__init__.py | 14 +- .../featurestore_service/transports/base.py | 359 +- .../featurestore_service/transports/grpc.py | 357 +- .../transports/grpc_asyncio.py | 376 +- .../index_endpoint_service/__init__.py | 4 +- .../index_endpoint_service/async_client.py | 336 +- .../services/index_endpoint_service/client.py | 411 +- .../services/index_endpoint_service/pagers.py | 51 +- .../transports/__init__.py | 14 +- .../index_endpoint_service/transports/base.py | 155 +- .../index_endpoint_service/transports/grpc.py | 173 +- .../transports/grpc_asyncio.py | 179 +- .../services/index_service/__init__.py | 4 +- .../services/index_service/async_client.py | 247 +- .../services/index_service/client.py | 316 +- .../services/index_service/pagers.py | 45 +- .../index_service/transports/__init__.py | 10 +- .../services/index_service/transports/base.py | 134 +- .../services/index_service/transports/grpc.py | 133 +- .../index_service/transports/grpc_asyncio.py | 134 +- .../services/job_service/client.py | 1386 ++--- .../services/metadata_service/async_client.py | 34 +- .../services/metadata_service/client.py | 1179 ++-- .../metadata_service/transports/base.py | 468 +- .../metadata_service/transports/grpc.py | 19 +- .../transports/grpc_asyncio.py | 19 +- .../services/migration_service/client.py | 282 +- .../services/model_service/client.py | 555 +- .../services/pipeline_service/client.py | 333 +- .../specialist_pool_service/client.py | 309 +- .../aiplatform_v1beta1/types/__init__.py | 802 ++- .../types/deployed_index_ref.py | 5 +- .../aiplatform_v1beta1/types/entity_type.py | 17 +- .../cloud/aiplatform_v1beta1/types/feature.py | 26 +- .../types/feature_selector.py | 11 +- .../aiplatform_v1beta1/types/featurestore.py | 22 +- .../types/featurestore_monitoring.py | 15 +- .../types/featurestore_online_service.py | 115 +- .../types/featurestore_service.py | 235 +- .../cloud/aiplatform_v1beta1/types/index.py | 21 +- .../types/index_endpoint.py | 47 +- .../types/index_endpoint_service.py | 68 +- .../aiplatform_v1beta1/types/index_service.py | 74 +- google/cloud/aiplatform_v1beta1/types/io.py | 36 +- .../types/metadata_service.py | 150 +- .../cloud/aiplatform_v1beta1/types/types.py | 9 +- .../aiplatform_v1/test_migration_service.py | 940 +-- ...est_featurestore_online_serving_service.py | 762 ++- .../test_featurestore_service.py | 3269 +++++------ .../test_index_endpoint_service.py | 1534 ++--- .../aiplatform_v1beta1/test_index_service.py | 1277 ++--- .../test_metadata_service.py | 5092 ++++++++--------- .../test_migration_service.py | 944 +-- 80 files changed, 15857 insertions(+), 16345 deletions(-) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 3e14ad0e50..3868a97a4d 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,14 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry['grpc'] = DatasetServiceGrpcTransport - _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[DatasetServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +118,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +153,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DatasetServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,110 +169,149 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str,str]: + def parse_annotation_path(path: str) -> Dict[str, str]: """Parse a annotation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str,str]: + def parse_annotation_spec_path(path: str) -> Dict[str, str]: """Parse a annotation_spec path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str,str]: + def parse_data_item_path(path: str) -> Dict[str, str]: """Parse a data_item path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -316,7 +355,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -326,7 +367,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -338,7 +381,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -350,8 +395,10 @@ def __init__(self, *, if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -370,15 +417,16 @@ def __init__(self, *, client_info=client_info, ) - def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a Dataset. Args: @@ -419,8 +467,10 @@ def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -444,18 +494,11 @@ def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -468,14 +511,15 @@ def create_dataset(self, # Done; return the response. return response - def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -507,8 +551,10 @@ def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -530,31 +576,25 @@ def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -599,8 +639,10 @@ def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -624,30 +666,26 @@ def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -682,8 +720,10 @@ def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -705,39 +745,30 @@ def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a Dataset. Args: @@ -783,8 +814,10 @@ def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -806,18 +839,11 @@ def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -830,15 +856,16 @@ def delete_dataset(self, # Done; return the response. return response - def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Imports data into a Dataset. Args: @@ -882,8 +909,10 @@ def import_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -907,18 +936,11 @@ def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -931,15 +953,16 @@ def import_data(self, # Done; return the response. return response - def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Exports data from a Dataset. Args: @@ -982,8 +1005,10 @@ def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -1007,18 +1032,11 @@ def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1031,14 +1049,15 @@ def export_data(self, # Done; return the response. return response - def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -1074,8 +1093,10 @@ def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1097,39 +1118,30 @@ def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1163,8 +1175,10 @@ def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1186,30 +1200,24 @@ def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1246,8 +1254,10 @@ def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1269,47 +1279,30 @@ def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceClient', -) +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py index 9f9b80b9a4..10653cbf25 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,8 +81,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -91,17 +91,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -110,56 +112,35 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, - default_timeout=5.0, - client_info=client_info, + self.create_dataset, default_timeout=5.0, client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, - default_timeout=5.0, - client_info=client_info, + self.get_dataset, default_timeout=5.0, client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, - default_timeout=5.0, - client_info=client_info, + self.update_dataset, default_timeout=5.0, client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, - default_timeout=5.0, - client_info=client_info, + self.list_datasets, default_timeout=5.0, client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, - default_timeout=5.0, - client_info=client_info, + self.delete_dataset, default_timeout=5.0, client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( - self.import_data, - default_timeout=5.0, - client_info=client_info, + self.import_data, default_timeout=5.0, client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( - self.export_data, - default_timeout=5.0, - client_info=client_info, + self.export_data, default_timeout=5.0, client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, - default_timeout=5.0, - client_info=client_info, + self.list_data_items, default_timeout=5.0, client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, - default_timeout=5.0, - client_info=client_info, + self.get_annotation_spec, default_timeout=5.0, client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, - default_timeout=5.0, - client_info=client_info, + self.list_annotations, default_timeout=5.0, client_info=client_info, ), - } @property @@ -168,96 +149,106 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_dataset(self) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_dataset( + self, + ) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_dataset(self) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[ - dataset.Dataset, - typing.Awaitable[dataset.Dataset] - ]]: + def get_dataset( + self, + ) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], + ]: raise NotImplementedError() @property - def update_dataset(self) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[ - gca_dataset.Dataset, - typing.Awaitable[gca_dataset.Dataset] - ]]: + def update_dataset( + self, + ) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], + ]: raise NotImplementedError() @property - def list_datasets(self) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse] - ]]: + def list_datasets( + self, + ) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse], + ], + ]: raise NotImplementedError() @property - def delete_dataset(self) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_dataset( + self, + ) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def import_data(self) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def import_data( + self, + ) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_data(self) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_data( + self, + ) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def list_data_items(self) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse] - ]]: + def list_data_items( + self, + ) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse], + ], + ]: raise NotImplementedError() @property - def get_annotation_spec(self) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec] - ]]: + def get_annotation_spec( + self, + ) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec], + ], + ]: raise NotImplementedError() @property - def list_annotations(self) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse] - ]]: + def list_annotations( + self, + ) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'DatasetServiceTransport', -) +__all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 9be4771620..c5a52fd3ed 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -56,13 +56,14 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry['grpc'] = EndpointServiceGrpcTransport - _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[EndpointServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry["grpc"] = EndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -113,7 +114,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -148,9 +149,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: EndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,88 +165,104 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -290,7 +306,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -300,7 +318,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -312,7 +332,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -324,8 +346,10 @@ def __init__(self, *, if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -344,15 +368,16 @@ def __init__(self, *, client_info=client_info, ) - def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates an Endpoint. Args: @@ -392,8 +417,10 @@ def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -417,18 +444,11 @@ def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -441,14 +461,15 @@ def create_endpoint(self, # Done; return the response. return response - def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -481,8 +502,10 @@ def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -504,30 +527,24 @@ def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -563,8 +580,10 @@ def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -586,40 +605,31 @@ def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -659,8 +669,10 @@ def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -684,30 +696,26 @@ def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes an Endpoint. Args: @@ -753,8 +761,10 @@ def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -776,18 +786,11 @@ def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -800,16 +803,19 @@ def delete_endpoint(self, # Done; return the response. return response - def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -878,8 +884,10 @@ def deploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -905,18 +913,11 @@ def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -929,16 +930,19 @@ def deploy_model(self, # Done; return the response. return response - def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -998,8 +1002,10 @@ def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -1025,18 +1031,11 @@ def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1050,21 +1049,14 @@ def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceClient', -) +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index 65e049d43f..054d6c9b01 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -80,8 +80,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -90,17 +90,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -109,41 +111,26 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, - default_timeout=5.0, - client_info=client_info, + self.create_endpoint, default_timeout=5.0, client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, - default_timeout=5.0, - client_info=client_info, + self.get_endpoint, default_timeout=5.0, client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, - default_timeout=5.0, - client_info=client_info, + self.list_endpoints, default_timeout=5.0, client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, - default_timeout=5.0, - client_info=client_info, + self.update_endpoint, default_timeout=5.0, client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, - default_timeout=5.0, - client_info=client_info, + self.delete_endpoint, default_timeout=5.0, client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, - default_timeout=5.0, - client_info=client_info, + self.deploy_model, default_timeout=5.0, client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, - default_timeout=5.0, - client_info=client_info, + self.undeploy_model, default_timeout=5.0, client_info=client_info, ), - } @property @@ -152,69 +139,70 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_endpoint(self) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_endpoint(self) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[ - endpoint.Endpoint, - typing.Awaitable[endpoint.Endpoint] - ]]: + def get_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def list_endpoints(self) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse] - ]]: + def list_endpoints( + self, + ) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse], + ], + ]: raise NotImplementedError() @property - def update_endpoint(self) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[ - gca_endpoint.Endpoint, - typing.Awaitable[gca_endpoint.Endpoint] - ]]: + def update_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def delete_endpoint(self) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def deploy_model(self) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def deploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def undeploy_model(self) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def undeploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'EndpointServiceTransport', -) +__all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index a3cc318097..d5332ddc61 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -23,20 +23,22 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import completion_stats from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job @@ -44,7 +46,9 @@ from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.cloud.aiplatform_v1.types import job_state from google.cloud.aiplatform_v1.types import machine_resources @@ -69,13 +73,12 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry['grpc'] = JobServiceGrpcTransport - _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[JobServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -126,7 +129,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -161,9 +164,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: JobServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -178,143 +180,194 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: + def batch_prediction_job_path( + project: str, location: str, batch_prediction_job: str, + ) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, + location=location, + batch_prediction_job=batch_prediction_job, + ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str,location: str,custom_job: str,) -> str: + def custom_job_path(project: str, location: str, custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str,str]: + def parse_custom_job_path(path: str) -> Dict[str, str]: """Parse a custom_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str,str]: + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: + def hyperparameter_tuning_job_path( + project: str, location: str, hyperparameter_tuning_job: str, + ) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str,location: str,study: str,trial: str,) -> str: + def trial_path(project: str, location: str, study: str, trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) @staticmethod - def parse_trial_path(path: str) -> Dict[str,str]: + def parse_trial_path(path: str) -> Dict[str, str]: """Parse a trial path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -358,7 +411,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -368,7 +423,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -380,7 +437,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -392,8 +451,10 @@ def __init__(self, *, if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -412,15 +473,16 @@ def __init__(self, *, client_info=client_info, ) - def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -465,8 +527,10 @@ def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -490,30 +554,24 @@ def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -551,8 +609,10 @@ def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -574,30 +634,24 @@ def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -633,8 +687,10 @@ def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -656,39 +712,30 @@ def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a CustomJob. Args: @@ -734,8 +781,10 @@ def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -757,18 +806,11 @@ def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -781,14 +823,15 @@ def delete_custom_job(self, # Done; return the response. return response - def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -826,8 +869,10 @@ def cancel_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -849,28 +894,24 @@ def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -910,8 +951,10 @@ def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -935,30 +978,24 @@ def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -992,8 +1029,10 @@ def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -1015,30 +1054,24 @@ def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1073,8 +1106,10 @@ def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1096,39 +1131,30 @@ def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1175,8 +1201,10 @@ def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1198,18 +1226,11 @@ def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1222,14 +1243,15 @@ def delete_data_labeling_job(self, # Done; return the response. return response - def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1257,8 +1279,10 @@ def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1280,28 +1304,24 @@ def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1343,8 +1363,10 @@ def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1363,35 +1385,31 @@ def create_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1427,8 +1445,10 @@ def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1445,35 +1465,31 @@ def get_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1509,8 +1525,10 @@ def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1527,44 +1545,37 @@ def list_hyperparameter_tuning_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1611,8 +1622,10 @@ def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1629,23 +1642,18 @@ def delete_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1658,14 +1666,15 @@ def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1706,8 +1715,10 @@ def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1724,33 +1735,31 @@ def cancel_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1795,8 +1804,10 @@ def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1815,35 +1826,31 @@ def create_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1881,8 +1888,10 @@ def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1904,30 +1913,24 @@ def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1963,8 +1966,10 @@ def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -1981,44 +1986,37 @@ def list_batch_prediction_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2066,8 +2064,10 @@ def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2084,23 +2084,18 @@ def delete_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -2113,14 +2108,15 @@ def delete_batch_prediction_job(self, # Done; return the response. return response - def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -2159,8 +2155,10 @@ def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2177,40 +2175,30 @@ def cancel_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceClient', -) +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py index 0292f60059..5cddf58749 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -21,19 +21,23 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job -from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1.types import data_labeling_job from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -42,29 +46,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -87,8 +91,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -97,17 +101,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -116,29 +122,19 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, - default_timeout=5.0, - client_info=client_info, + self.create_custom_job, default_timeout=5.0, client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, - default_timeout=5.0, - client_info=client_info, + self.get_custom_job, default_timeout=5.0, client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, - default_timeout=5.0, - client_info=client_info, + self.list_custom_jobs, default_timeout=5.0, client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, - default_timeout=5.0, - client_info=client_info, + self.delete_custom_job, default_timeout=5.0, client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, - default_timeout=5.0, - client_info=client_info, + self.cancel_custom_job, default_timeout=5.0, client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, @@ -215,7 +211,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), - } @property @@ -224,186 +219,216 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_custom_job(self) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, - typing.Awaitable[gca_custom_job.CustomJob] - ]]: + def create_custom_job( + self, + ) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] + ], + ]: raise NotImplementedError() @property - def get_custom_job(self) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[ - custom_job.CustomJob, - typing.Awaitable[custom_job.CustomJob] - ]]: + def get_custom_job( + self, + ) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], + ]: raise NotImplementedError() @property - def list_custom_jobs(self) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse] - ]]: + def list_custom_jobs( + self, + ) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_custom_job(self) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_custom_job( + self, + ) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_custom_job(self) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_custom_job( + self, + ) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_data_labeling_job(self) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob] - ]]: + def create_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def get_data_labeling_job(self) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob] - ]]: + def get_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def list_data_labeling_jobs(self) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse] - ]]: + def list_data_labeling_jobs( + self, + ) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_data_labeling_job(self) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_data_labeling_job(self) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def create_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def get_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def get_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_batch_prediction_job(self) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] - ]]: + def create_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def get_batch_prediction_job(self) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - typing.Union[ - batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob] - ]]: + def get_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def list_batch_prediction_jobs(self) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse] - ]]: + def list_batch_prediction_jobs( + self, + ) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_batch_prediction_job(self) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_batch_prediction_job(self) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'JobServiceTransport', -) +__all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 9f505b26b2..0d6e0fdbd6 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,13 +50,14 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry['grpc'] = MigrationServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MigrationServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry["grpc"] = MigrationServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -110,7 +111,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -145,9 +146,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MigrationServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -162,143 +162,183 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: + def annotated_dataset_path( + project: str, dataset: str, annotated_dataset: str, + ) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str,str]: + def parse_annotated_dataset_path(path: str) -> Dict[str, str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def version_path(project: str,model: str,version: str,) -> str: + def version_path(project: str, model: str, version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + return "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) @staticmethod - def parse_version_path(path: str) -> Dict[str,str]: + def parse_version_path(path: str) -> Dict[str, str]: """Parse a version path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -342,7 +382,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -352,7 +394,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -364,7 +408,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -376,8 +422,10 @@ def __init__(self, *, if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -396,14 +444,15 @@ def __init__(self, *, client_info=client_info, ) - def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -444,8 +493,10 @@ def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -462,45 +513,40 @@ def search_migratable_resources(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -549,8 +595,10 @@ def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -574,18 +622,11 @@ def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation.from_gapic( @@ -599,21 +640,14 @@ def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceClient', -) +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index f0237a4359..e93d31639a 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,12 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry['grpc'] = ModelServiceGrpcTransport - _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[ModelServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +116,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +151,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,121 +167,162 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_path(path: str) -> Dict[str, str]: """Parse a model_evaluation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -327,7 +366,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -337,7 +378,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -349,7 +392,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -361,8 +406,10 @@ def __init__(self, *, if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -381,15 +428,16 @@ def __init__(self, *, client_info=client_info, ) - def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -432,8 +480,10 @@ def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -457,18 +507,11 @@ def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -481,14 +524,15 @@ def upload_model(self, # Done; return the response. return response - def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -518,8 +562,10 @@ def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -541,30 +587,24 @@ def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -600,8 +640,10 @@ def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -623,40 +665,31 @@ def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -694,8 +727,10 @@ def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -719,30 +754,26 @@ def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -790,8 +821,10 @@ def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -813,18 +846,11 @@ def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -837,15 +863,16 @@ def delete_model(self, # Done; return the response. return response - def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -893,8 +920,10 @@ def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -918,18 +947,11 @@ def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -942,14 +964,15 @@ def export_model(self, # Done; return the response. return response - def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -985,8 +1008,10 @@ def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -1008,30 +1033,24 @@ def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1067,8 +1086,10 @@ def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1090,39 +1111,30 @@ def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1158,8 +1170,10 @@ def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1176,35 +1190,31 @@ def get_model_evaluation_slice(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1241,8 +1251,10 @@ def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1259,52 +1271,37 @@ def list_model_evaluation_slices(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceClient', -) +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py index 262cb1c736..5252ac9c36 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -37,29 +37,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -82,8 +82,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -92,17 +92,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -111,39 +113,25 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, - default_timeout=5.0, - client_info=client_info, + self.upload_model, default_timeout=5.0, client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( - self.get_model, - default_timeout=5.0, - client_info=client_info, + self.get_model, default_timeout=5.0, client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( - self.list_models, - default_timeout=5.0, - client_info=client_info, + self.list_models, default_timeout=5.0, client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( - self.update_model, - default_timeout=5.0, - client_info=client_info, + self.update_model, default_timeout=5.0, client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, - default_timeout=5.0, - client_info=client_info, + self.delete_model, default_timeout=5.0, client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( - self.export_model, - default_timeout=5.0, - client_info=client_info, + self.export_model, default_timeout=5.0, client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( - self.get_model_evaluation, - default_timeout=5.0, - client_info=client_info, + self.get_model_evaluation, default_timeout=5.0, client_info=client_info, ), self.list_model_evaluations: gapic_v1.method.wrap_method( self.list_model_evaluations, @@ -160,7 +148,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), - } @property @@ -169,96 +156,109 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def upload_model(self) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def upload_model( + self, + ) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model(self) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[ - model.Model, - typing.Awaitable[model.Model] - ]]: + def get_model( + self, + ) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[model.Model, typing.Awaitable[model.Model]], + ]: raise NotImplementedError() @property - def list_models(self) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse] - ]]: + def list_models( + self, + ) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse], + ], + ]: raise NotImplementedError() @property - def update_model(self) -> typing.Callable[ - [model_service.UpdateModelRequest], - typing.Union[ - gca_model.Model, - typing.Awaitable[gca_model.Model] - ]]: + def update_model( + self, + ) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], + ]: raise NotImplementedError() @property - def delete_model(self) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_model( + self, + ) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_model(self) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_model( + self, + ) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model_evaluation(self) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation] - ]]: + def get_model_evaluation( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation], + ], + ]: raise NotImplementedError() @property - def list_model_evaluations(self) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse] - ]]: + def list_model_evaluations( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse], + ], + ]: raise NotImplementedError() @property - def get_model_evaluation_slice(self) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] - ]]: + def get_model_evaluation_slice( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ], + ]: raise NotImplementedError() @property - def list_model_evaluation_slices(self) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] - ]]: + def list_model_evaluation_slices( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'ModelServiceTransport', -) +__all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index 39d6f60f89..3b5e486ee5 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -59,13 +59,14 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry['grpc'] = PipelineServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PipelineServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry["grpc"] = PipelineServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -151,9 +152,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PipelineServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,99 +168,122 @@ def transport(self) -> PipelineServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -304,7 +327,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -314,7 +339,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -326,7 +353,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -338,8 +367,10 @@ def __init__(self, *, if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -358,15 +389,16 @@ def __init__(self, *, client_info=client_info, ) - def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -411,8 +443,10 @@ def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -436,30 +470,24 @@ def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -497,8 +525,10 @@ def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -520,30 +550,24 @@ def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -579,8 +603,10 @@ def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -602,39 +628,30 @@ def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -681,8 +698,10 @@ def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -704,18 +723,11 @@ def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -728,14 +740,15 @@ def delete_training_pipeline(self, # Done; return the response. return response - def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -775,8 +788,10 @@ def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -798,35 +813,23 @@ def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceClient', -) +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index ebba095d37..bee77f7896 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -76,8 +76,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -86,17 +86,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -105,23 +107,21 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( - self.predict, - default_timeout=5.0, - client_info=client_info, + self.predict, default_timeout=5.0, client_info=client_info, ), - } @property - def predict(self) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse] - ]]: + def predict( + self, + ) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'PredictionServiceTransport', -) +__all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index 968bf5dbd4..b319783aa0 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,13 +54,16 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport - _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +120,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +155,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: SpecialistPoolServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,77 +171,88 @@ def transport(self) -> SpecialistPoolServiceTransport: return self._transport @staticmethod - def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str: + def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" - return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) @staticmethod - def parse_specialist_pool_path(path: str) -> Dict[str,str]: + def parse_specialist_pool_path(path: str) -> Dict[str, str]: """Parse a specialist_pool path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -283,7 +296,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -293,7 +308,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -305,7 +322,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -317,8 +336,10 @@ def __init__(self, *, if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -337,15 +358,16 @@ def __init__(self, *, client_info=client_info, ) - def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -393,8 +415,10 @@ def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -418,18 +442,11 @@ def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -442,14 +459,15 @@ def create_specialist_pool(self, # Done; return the response. return response - def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -492,8 +510,10 @@ def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -515,30 +535,24 @@ def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -574,8 +588,10 @@ def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -597,39 +613,30 @@ def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -676,8 +683,10 @@ def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -699,18 +708,11 @@ def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -723,15 +725,16 @@ def delete_specialist_pool(self, # Done; return the response. return response - def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -778,8 +781,10 @@ def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -803,18 +808,13 @@ def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -828,21 +828,14 @@ def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceClient', -) +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py index e05bc7d77c..bf7e0209d7 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -79,8 +79,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -89,17 +89,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -113,9 +115,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, - default_timeout=5.0, - client_info=client_info, + self.get_specialist_pool, default_timeout=5.0, client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, @@ -132,7 +132,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), - } @property @@ -141,51 +140,55 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool] - ]]: + def get_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool], + ], + ]: raise NotImplementedError() @property - def list_specialist_pools(self) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ]]: + def list_specialist_pools( + self, + ) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ], + ]: raise NotImplementedError() @property - def delete_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def update_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'SpecialistPoolServiceTransport', -) +__all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 968477c4d7..3f605a0fcb 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -17,7 +17,9 @@ from .services.dataset_service import DatasetServiceClient from .services.endpoint_service import EndpointServiceClient -from .services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceClient +from .services.featurestore_online_serving_service import ( + FeaturestoreOnlineServingServiceClient, +) from .services.featurestore_service import FeaturestoreServiceClient from .services.index_endpoint_service import IndexEndpointServiceClient from .services.index_service import IndexServiceClient @@ -282,11 +284,19 @@ from .types.model import ModelContainerSpec from .types.model import Port from .types.model import PredictSchemata -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringBigQueryTable +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringBigQueryTable, +) from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringJob -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringObjectiveConfig -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringObjectiveType -from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringScheduleConfig +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringObjectiveConfig, +) +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringObjectiveType, +) +from .types.model_deployment_monitoring_job import ( + ModelDeploymentMonitoringScheduleConfig, +) from .types.model_deployment_monitoring_job import ModelMonitoringStatsAnomalies from .types.model_evaluation import ModelEvaluation from .types.model_evaluation_slice import ModelEvaluationSlice @@ -373,359 +383,359 @@ __all__ = ( - 'AcceleratorType', - 'ActiveLearningConfig', - 'AddContextArtifactsAndExecutionsRequest', - 'AddContextArtifactsAndExecutionsResponse', - 'AddContextChildrenRequest', - 'AddContextChildrenResponse', - 'AddExecutionEventsRequest', - 'AddExecutionEventsResponse', - 'AddTrialMeasurementRequest', - 'Annotation', - 'AnnotationSpec', - 'Artifact', - 'Attribution', - 'AutomaticResources', - 'AutoscalingMetricSpec', - 'AvroSource', - 'BatchCreateFeaturesOperationMetadata', - 'BatchCreateFeaturesRequest', - 'BatchCreateFeaturesResponse', - 'BatchDedicatedResources', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'BatchPredictionJob', - 'BatchReadFeatureValuesOperationMetadata', - 'BatchReadFeatureValuesRequest', - 'BatchReadFeatureValuesResponse', - 'BigQueryDestination', - 'BigQuerySource', - 'BoolArray', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CancelTrainingPipelineRequest', - 'CheckTrialEarlyStoppingStateMetatdata', - 'CheckTrialEarlyStoppingStateRequest', - 'CheckTrialEarlyStoppingStateResponse', - 'CompleteTrialRequest', - 'CompletionStats', - 'ContainerRegistryDestination', - 'ContainerSpec', - 'Context', - 'CreateArtifactRequest', - 'CreateBatchPredictionJobRequest', - 'CreateContextRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'CreateEntityTypeOperationMetadata', - 'CreateEntityTypeRequest', - 'CreateExecutionRequest', - 'CreateFeatureOperationMetadata', - 'CreateFeatureRequest', - 'CreateFeaturestoreOperationMetadata', - 'CreateFeaturestoreRequest', - 'CreateHyperparameterTuningJobRequest', - 'CreateIndexEndpointOperationMetadata', - 'CreateIndexEndpointRequest', - 'CreateIndexOperationMetadata', - 'CreateIndexRequest', - 'CreateMetadataSchemaRequest', - 'CreateMetadataStoreOperationMetadata', - 'CreateMetadataStoreRequest', - 'CreateModelDeploymentMonitoringJobRequest', - 'CreateSpecialistPoolOperationMetadata', - 'CreateSpecialistPoolRequest', - 'CreateStudyRequest', - 'CreateTrainingPipelineRequest', - 'CreateTrialRequest', - 'CsvDestination', - 'CsvSource', - 'CustomJob', - 'CustomJobSpec', - 'DataItem', - 'DataLabelingJob', - 'Dataset', - 'DatasetServiceClient', - 'DedicatedResources', - 'DeleteBatchPredictionJobRequest', - 'DeleteContextRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteDatasetRequest', - 'DeleteEndpointRequest', - 'DeleteEntityTypeRequest', - 'DeleteFeatureRequest', - 'DeleteFeaturestoreRequest', - 'DeleteHyperparameterTuningJobRequest', - 'DeleteIndexEndpointRequest', - 'DeleteIndexRequest', - 'DeleteMetadataStoreOperationMetadata', - 'DeleteMetadataStoreRequest', - 'DeleteModelDeploymentMonitoringJobRequest', - 'DeleteModelRequest', - 'DeleteOperationMetadata', - 'DeleteSpecialistPoolRequest', - 'DeleteStudyRequest', - 'DeleteTrainingPipelineRequest', - 'DeleteTrialRequest', - 'DeployIndexOperationMetadata', - 'DeployIndexRequest', - 'DeployIndexResponse', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployedIndex', - 'DeployedIndexAuthConfig', - 'DeployedIndexRef', - 'DeployedModel', - 'DeployedModelRef', - 'DestinationFeatureSetting', - 'DiskSpec', - 'DoubleArray', - 'EncryptionSpec', - 'Endpoint', - 'EndpointServiceClient', - 'EntityType', - 'EnvVar', - 'Event', - 'Execution', - 'ExplainRequest', - 'ExplainResponse', - 'Explanation', - 'ExplanationMetadata', - 'ExplanationMetadataOverride', - 'ExplanationParameters', - 'ExplanationSpec', - 'ExplanationSpecOverride', - 'ExportDataConfig', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'Feature', - 'FeatureNoiseSigma', - 'FeatureSelector', - 'FeatureStatsAnomaly', - 'FeatureValue', - 'FeatureValueDestination', - 'FeatureValueList', - 'Featurestore', - 'FeaturestoreMonitoringConfig', - 'FeaturestoreOnlineServingServiceClient', - 'FeaturestoreServiceClient', - 'FilterSplit', - 'FractionSplit', - 'GcsDestination', - 'GcsSource', - 'GenericOperationMetadata', - 'GetAnnotationSpecRequest', - 'GetArtifactRequest', - 'GetBatchPredictionJobRequest', - 'GetContextRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetDatasetRequest', - 'GetEndpointRequest', - 'GetEntityTypeRequest', - 'GetExecutionRequest', - 'GetFeatureRequest', - 'GetFeaturestoreRequest', - 'GetHyperparameterTuningJobRequest', - 'GetIndexEndpointRequest', - 'GetIndexRequest', - 'GetMetadataSchemaRequest', - 'GetMetadataStoreRequest', - 'GetModelDeploymentMonitoringJobRequest', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'GetSpecialistPoolRequest', - 'GetStudyRequest', - 'GetTrainingPipelineRequest', - 'GetTrialRequest', - 'HyperparameterTuningJob', - 'IdMatcher', - 'ImportDataConfig', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'ImportFeatureValuesOperationMetadata', - 'ImportFeatureValuesRequest', - 'ImportFeatureValuesResponse', - 'Index', - 'IndexEndpoint', - 'IndexEndpointServiceClient', - 'IndexPrivateEndpoints', - 'IndexServiceClient', - 'InputDataConfig', - 'Int64Array', - 'IntegratedGradientsAttribution', - 'JobServiceClient', - 'JobState', - 'LineageSubgraph', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListArtifactsRequest', - 'ListArtifactsResponse', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListContextsRequest', - 'ListContextsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'ListEntityTypesRequest', - 'ListEntityTypesResponse', - 'ListExecutionsRequest', - 'ListExecutionsResponse', - 'ListFeaturesRequest', - 'ListFeaturesResponse', - 'ListFeaturestoresRequest', - 'ListFeaturestoresResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'ListIndexEndpointsRequest', - 'ListIndexEndpointsResponse', - 'ListIndexesRequest', - 'ListIndexesResponse', - 'ListMetadataSchemasRequest', - 'ListMetadataSchemasResponse', - 'ListMetadataStoresRequest', - 'ListMetadataStoresResponse', - 'ListModelDeploymentMonitoringJobsRequest', - 'ListModelDeploymentMonitoringJobsResponse', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'ListOptimalTrialsRequest', - 'ListOptimalTrialsResponse', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'ListStudiesRequest', - 'ListStudiesResponse', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'ListTrialsRequest', - 'ListTrialsResponse', - 'LookupStudyRequest', - 'MachineSpec', - 'ManualBatchTuningParameters', - 'Measurement', - 'MetadataSchema', - 'MetadataStore', - 'MigratableResource', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'MigrationServiceClient', - 'Model', - 'ModelContainerSpec', - 'ModelDeploymentMonitoringBigQueryTable', - 'ModelDeploymentMonitoringJob', - 'ModelDeploymentMonitoringObjectiveConfig', - 'ModelDeploymentMonitoringObjectiveType', - 'ModelDeploymentMonitoringScheduleConfig', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'ModelExplanation', - 'ModelMonitoringAlertConfig', - 'ModelMonitoringObjectiveConfig', - 'ModelMonitoringStatsAnomalies', - 'ModelServiceClient', - 'NearestNeighborSearchOperationMetadata', - 'PauseModelDeploymentMonitoringJobRequest', - 'PipelineServiceClient', - 'PipelineState', - 'Port', - 'PredefinedSplit', - 'PredictRequest', - 'PredictResponse', - 'PredictSchemata', - 'PredictionServiceClient', - 'PythonPackageSpec', - 'QueryArtifactLineageSubgraphRequest', - 'QueryContextLineageSubgraphRequest', - 'QueryExecutionInputsAndOutputsRequest', - 'ReadFeatureValuesRequest', - 'ReadFeatureValuesResponse', - 'ReadSetting', - 'ResourcesConsumed', - 'ResumeModelDeploymentMonitoringJobRequest', - 'SampleConfig', - 'SampledShapleyAttribution', - 'SamplingStrategy', - 'Scheduling', - 'SearchFeaturesRequest', - 'SearchFeaturesResponse', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', - 'SmoothGradConfig', - 'SpecialistPool', - 'SpecialistPoolServiceClient', - 'StopTrialRequest', - 'StreamingReadFeatureValuesRequest', - 'StringArray', - 'Study', - 'StudySpec', - 'SuggestTrialsMetadata', - 'SuggestTrialsRequest', - 'SuggestTrialsResponse', - 'TFRecordDestination', - 'ThresholdConfig', - 'TimestampSplit', - 'TrainingConfig', - 'TrainingPipeline', - 'Trial', - 'UndeployIndexOperationMetadata', - 'UndeployIndexRequest', - 'UndeployIndexResponse', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateArtifactRequest', - 'UpdateContextRequest', - 'UpdateDatasetRequest', - 'UpdateEndpointRequest', - 'UpdateEntityTypeRequest', - 'UpdateExecutionRequest', - 'UpdateFeatureRequest', - 'UpdateFeaturestoreOperationMetadata', - 'UpdateFeaturestoreRequest', - 'UpdateIndexEndpointRequest', - 'UpdateIndexOperationMetadata', - 'UpdateIndexRequest', - 'UpdateModelDeploymentMonitoringJobOperationMetadata', - 'UpdateModelDeploymentMonitoringJobRequest', - 'UpdateModelRequest', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 'UploadModelResponse', - 'UserActionReference', - 'VizierServiceClient', - 'WorkerPoolSpec', - 'XraiAttribution', -'MetadataServiceClient', + "AcceleratorType", + "ActiveLearningConfig", + "AddContextArtifactsAndExecutionsRequest", + "AddContextArtifactsAndExecutionsResponse", + "AddContextChildrenRequest", + "AddContextChildrenResponse", + "AddExecutionEventsRequest", + "AddExecutionEventsResponse", + "AddTrialMeasurementRequest", + "Annotation", + "AnnotationSpec", + "Artifact", + "Attribution", + "AutomaticResources", + "AutoscalingMetricSpec", + "AvroSource", + "BatchCreateFeaturesOperationMetadata", + "BatchCreateFeaturesRequest", + "BatchCreateFeaturesResponse", + "BatchDedicatedResources", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "BatchPredictionJob", + "BatchReadFeatureValuesOperationMetadata", + "BatchReadFeatureValuesRequest", + "BatchReadFeatureValuesResponse", + "BigQueryDestination", + "BigQuerySource", + "BoolArray", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CancelTrainingPipelineRequest", + "CheckTrialEarlyStoppingStateMetatdata", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CompleteTrialRequest", + "CompletionStats", + "ContainerRegistryDestination", + "ContainerSpec", + "Context", + "CreateArtifactRequest", + "CreateBatchPredictionJobRequest", + "CreateContextRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "CreateEntityTypeOperationMetadata", + "CreateEntityTypeRequest", + "CreateExecutionRequest", + "CreateFeatureOperationMetadata", + "CreateFeatureRequest", + "CreateFeaturestoreOperationMetadata", + "CreateFeaturestoreRequest", + "CreateHyperparameterTuningJobRequest", + "CreateIndexEndpointOperationMetadata", + "CreateIndexEndpointRequest", + "CreateIndexOperationMetadata", + "CreateIndexRequest", + "CreateMetadataSchemaRequest", + "CreateMetadataStoreOperationMetadata", + "CreateMetadataStoreRequest", + "CreateModelDeploymentMonitoringJobRequest", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "CreateStudyRequest", + "CreateTrainingPipelineRequest", + "CreateTrialRequest", + "CsvDestination", + "CsvSource", + "CustomJob", + "CustomJobSpec", + "DataItem", + "DataLabelingJob", + "Dataset", + "DatasetServiceClient", + "DedicatedResources", + "DeleteBatchPredictionJobRequest", + "DeleteContextRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteDatasetRequest", + "DeleteEndpointRequest", + "DeleteEntityTypeRequest", + "DeleteFeatureRequest", + "DeleteFeaturestoreRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteIndexEndpointRequest", + "DeleteIndexRequest", + "DeleteMetadataStoreOperationMetadata", + "DeleteMetadataStoreRequest", + "DeleteModelDeploymentMonitoringJobRequest", + "DeleteModelRequest", + "DeleteOperationMetadata", + "DeleteSpecialistPoolRequest", + "DeleteStudyRequest", + "DeleteTrainingPipelineRequest", + "DeleteTrialRequest", + "DeployIndexOperationMetadata", + "DeployIndexRequest", + "DeployIndexResponse", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "DeployedIndex", + "DeployedIndexAuthConfig", + "DeployedIndexRef", + "DeployedModel", + "DeployedModelRef", + "DestinationFeatureSetting", + "DiskSpec", + "DoubleArray", + "EncryptionSpec", + "Endpoint", + "EndpointServiceClient", + "EntityType", + "EnvVar", + "Event", + "Execution", + "ExplainRequest", + "ExplainResponse", + "Explanation", + "ExplanationMetadata", + "ExplanationMetadataOverride", + "ExplanationParameters", + "ExplanationSpec", + "ExplanationSpecOverride", + "ExportDataConfig", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "Feature", + "FeatureNoiseSigma", + "FeatureSelector", + "FeatureStatsAnomaly", + "FeatureValue", + "FeatureValueDestination", + "FeatureValueList", + "Featurestore", + "FeaturestoreMonitoringConfig", + "FeaturestoreOnlineServingServiceClient", + "FeaturestoreServiceClient", + "FilterSplit", + "FractionSplit", + "GcsDestination", + "GcsSource", + "GenericOperationMetadata", + "GetAnnotationSpecRequest", + "GetArtifactRequest", + "GetBatchPredictionJobRequest", + "GetContextRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetDatasetRequest", + "GetEndpointRequest", + "GetEntityTypeRequest", + "GetExecutionRequest", + "GetFeatureRequest", + "GetFeaturestoreRequest", + "GetHyperparameterTuningJobRequest", + "GetIndexEndpointRequest", + "GetIndexRequest", + "GetMetadataSchemaRequest", + "GetMetadataStoreRequest", + "GetModelDeploymentMonitoringJobRequest", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "GetSpecialistPoolRequest", + "GetStudyRequest", + "GetTrainingPipelineRequest", + "GetTrialRequest", + "HyperparameterTuningJob", + "IdMatcher", + "ImportDataConfig", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "ImportFeatureValuesOperationMetadata", + "ImportFeatureValuesRequest", + "ImportFeatureValuesResponse", + "Index", + "IndexEndpoint", + "IndexEndpointServiceClient", + "IndexPrivateEndpoints", + "IndexServiceClient", + "InputDataConfig", + "Int64Array", + "IntegratedGradientsAttribution", + "JobServiceClient", + "JobState", + "LineageSubgraph", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListArtifactsRequest", + "ListArtifactsResponse", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListContextsRequest", + "ListContextsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "ListEndpointsRequest", + "ListEndpointsResponse", + "ListEntityTypesRequest", + "ListEntityTypesResponse", + "ListExecutionsRequest", + "ListExecutionsResponse", + "ListFeaturesRequest", + "ListFeaturesResponse", + "ListFeaturestoresRequest", + "ListFeaturestoresResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListIndexEndpointsRequest", + "ListIndexEndpointsResponse", + "ListIndexesRequest", + "ListIndexesResponse", + "ListMetadataSchemasRequest", + "ListMetadataSchemasResponse", + "ListMetadataStoresRequest", + "ListMetadataStoresResponse", + "ListModelDeploymentMonitoringJobsRequest", + "ListModelDeploymentMonitoringJobsResponse", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "ListStudiesRequest", + "ListStudiesResponse", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "ListTrialsRequest", + "ListTrialsResponse", + "LookupStudyRequest", + "MachineSpec", + "ManualBatchTuningParameters", + "Measurement", + "MetadataSchema", + "MetadataStore", + "MigratableResource", + "MigrateResourceRequest", + "MigrateResourceResponse", + "MigrationServiceClient", + "Model", + "ModelContainerSpec", + "ModelDeploymentMonitoringBigQueryTable", + "ModelDeploymentMonitoringJob", + "ModelDeploymentMonitoringObjectiveConfig", + "ModelDeploymentMonitoringObjectiveType", + "ModelDeploymentMonitoringScheduleConfig", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelExplanation", + "ModelMonitoringAlertConfig", + "ModelMonitoringObjectiveConfig", + "ModelMonitoringStatsAnomalies", + "ModelServiceClient", + "NearestNeighborSearchOperationMetadata", + "PauseModelDeploymentMonitoringJobRequest", + "PipelineServiceClient", + "PipelineState", + "Port", + "PredefinedSplit", + "PredictRequest", + "PredictResponse", + "PredictSchemata", + "PredictionServiceClient", + "PythonPackageSpec", + "QueryArtifactLineageSubgraphRequest", + "QueryContextLineageSubgraphRequest", + "QueryExecutionInputsAndOutputsRequest", + "ReadFeatureValuesRequest", + "ReadFeatureValuesResponse", + "ReadSetting", + "ResourcesConsumed", + "ResumeModelDeploymentMonitoringJobRequest", + "SampleConfig", + "SampledShapleyAttribution", + "SamplingStrategy", + "Scheduling", + "SearchFeaturesRequest", + "SearchFeaturesResponse", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "SearchModelDeploymentMonitoringStatsAnomaliesRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesResponse", + "SmoothGradConfig", + "SpecialistPool", + "SpecialistPoolServiceClient", + "StopTrialRequest", + "StreamingReadFeatureValuesRequest", + "StringArray", + "Study", + "StudySpec", + "SuggestTrialsMetadata", + "SuggestTrialsRequest", + "SuggestTrialsResponse", + "TFRecordDestination", + "ThresholdConfig", + "TimestampSplit", + "TrainingConfig", + "TrainingPipeline", + "Trial", + "UndeployIndexOperationMetadata", + "UndeployIndexRequest", + "UndeployIndexResponse", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateArtifactRequest", + "UpdateContextRequest", + "UpdateDatasetRequest", + "UpdateEndpointRequest", + "UpdateEntityTypeRequest", + "UpdateExecutionRequest", + "UpdateFeatureRequest", + "UpdateFeaturestoreOperationMetadata", + "UpdateFeaturestoreRequest", + "UpdateIndexEndpointRequest", + "UpdateIndexOperationMetadata", + "UpdateIndexRequest", + "UpdateModelDeploymentMonitoringJobOperationMetadata", + "UpdateModelDeploymentMonitoringJobRequest", + "UpdateModelRequest", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "UserActionReference", + "VizierServiceClient", + "WorkerPoolSpec", + "XraiAttribution", + "MetadataServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 0dfe93b1eb..abd38c0b7e 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,14 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry['grpc'] = DatasetServiceGrpcTransport - _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[DatasetServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +118,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +153,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DatasetServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,110 +169,149 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str,str]: + def parse_annotation_path(path: str) -> Dict[str, str]: """Parse a annotation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str,str]: + def parse_annotation_spec_path(path: str) -> Dict[str, str]: """Parse a annotation_spec path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str,str]: + def parse_data_item_path(path: str) -> Dict[str, str]: """Parse a data_item path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -316,7 +355,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -326,7 +367,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -338,7 +381,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -350,8 +395,10 @@ def __init__(self, *, if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -370,15 +417,16 @@ def __init__(self, *, client_info=client_info, ) - def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a Dataset. Args: @@ -419,8 +467,10 @@ def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -444,18 +494,11 @@ def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -468,14 +511,15 @@ def create_dataset(self, # Done; return the response. return response - def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -507,8 +551,10 @@ def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -530,31 +576,25 @@ def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -599,8 +639,10 @@ def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -624,30 +666,26 @@ def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -682,8 +720,10 @@ def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -705,39 +745,30 @@ def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a Dataset. Args: @@ -783,8 +814,10 @@ def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -806,18 +839,11 @@ def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -830,15 +856,16 @@ def delete_dataset(self, # Done; return the response. return response - def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Imports data into a Dataset. Args: @@ -882,8 +909,10 @@ def import_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -907,18 +936,11 @@ def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -931,15 +953,16 @@ def import_data(self, # Done; return the response. return response - def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Exports data from a Dataset. Args: @@ -982,8 +1005,10 @@ def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -1007,18 +1032,11 @@ def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1031,14 +1049,15 @@ def export_data(self, # Done; return the response. return response - def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -1074,8 +1093,10 @@ def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1097,39 +1118,30 @@ def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1162,8 +1174,10 @@ def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1185,30 +1199,24 @@ def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1244,8 +1252,10 @@ def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1267,47 +1277,30 @@ def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceClient', -) +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 21e209da37..3d6063d800 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -56,13 +56,14 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry['grpc'] = EndpointServiceGrpcTransport - _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[EndpointServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry["grpc"] = EndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -113,7 +114,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -148,9 +149,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: EndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -165,88 +165,104 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -290,7 +306,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -300,7 +318,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -312,7 +332,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -324,8 +346,10 @@ def __init__(self, *, if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -344,15 +368,16 @@ def __init__(self, *, client_info=client_info, ) - def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates an Endpoint. Args: @@ -392,8 +417,10 @@ def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -417,18 +444,11 @@ def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -441,14 +461,15 @@ def create_endpoint(self, # Done; return the response. return response - def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -481,8 +502,10 @@ def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -504,30 +527,24 @@ def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -563,8 +580,10 @@ def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -586,40 +605,31 @@ def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -659,8 +669,10 @@ def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -684,30 +696,26 @@ def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes an Endpoint. Args: @@ -753,8 +761,10 @@ def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -776,18 +786,11 @@ def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -800,16 +803,19 @@ def delete_endpoint(self, # Done; return the response. return response - def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -878,8 +884,10 @@ def deploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -905,18 +913,11 @@ def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -929,16 +930,19 @@ def deploy_model(self, # Done; return the response. return response - def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -998,8 +1002,10 @@ def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -1025,18 +1031,11 @@ def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1050,21 +1049,14 @@ def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceClient', -) +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py index d5da9ac80e..8fca4944ab 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import FeaturestoreOnlineServingServiceAsyncClient __all__ = ( - 'FeaturestoreOnlineServingServiceClient', - 'FeaturestoreOnlineServingServiceAsyncClient', + "FeaturestoreOnlineServingServiceClient", + "FeaturestoreOnlineServingServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py index adb54190b0..8e2d8bec62 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py @@ -21,17 +21,22 @@ from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import featurestore_online_service -from .transports.base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import FeaturestoreOnlineServingServiceGrpcAsyncIOTransport +from .transports.base import ( + FeaturestoreOnlineServingServiceTransport, + DEFAULT_CLIENT_INFO, +) +from .transports.grpc_asyncio import ( + FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, +) from .client import FeaturestoreOnlineServingServiceClient @@ -43,23 +48,47 @@ class FeaturestoreOnlineServingServiceAsyncClient: DEFAULT_ENDPOINT = FeaturestoreOnlineServingServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = FeaturestoreOnlineServingServiceClient.DEFAULT_MTLS_ENDPOINT - entity_type_path = staticmethod(FeaturestoreOnlineServingServiceClient.entity_type_path) - parse_entity_type_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_entity_type_path) + entity_type_path = staticmethod( + FeaturestoreOnlineServingServiceClient.entity_type_path + ) + parse_entity_type_path = staticmethod( + FeaturestoreOnlineServingServiceClient.parse_entity_type_path + ) - common_billing_account_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + FeaturestoreOnlineServingServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path + ) - common_folder_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_folder_path) + common_folder_path = staticmethod( + FeaturestoreOnlineServingServiceClient.common_folder_path + ) + parse_common_folder_path = staticmethod( + FeaturestoreOnlineServingServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + FeaturestoreOnlineServingServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + FeaturestoreOnlineServingServiceClient.parse_common_organization_path + ) - common_project_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_project_path) - parse_common_project_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_project_path) + common_project_path = staticmethod( + FeaturestoreOnlineServingServiceClient.common_project_path + ) + parse_common_project_path = staticmethod( + FeaturestoreOnlineServingServiceClient.parse_common_project_path + ) - common_location_path = staticmethod(FeaturestoreOnlineServingServiceClient.common_location_path) - parse_common_location_path = staticmethod(FeaturestoreOnlineServingServiceClient.parse_common_location_path) + common_location_path = staticmethod( + FeaturestoreOnlineServingServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + FeaturestoreOnlineServingServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -102,14 +131,21 @@ def transport(self) -> FeaturestoreOnlineServingServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(FeaturestoreOnlineServingServiceClient).get_transport_class, type(FeaturestoreOnlineServingServiceClient)) + get_transport_class = functools.partial( + type(FeaturestoreOnlineServingServiceClient).get_transport_class, + type(FeaturestoreOnlineServingServiceClient), + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, FeaturestoreOnlineServingServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[ + str, FeaturestoreOnlineServingServiceTransport + ] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the featurestore online serving service client. Args: @@ -148,17 +184,17 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def read_feature_values(self, - request: featurestore_online_service.ReadFeatureValuesRequest = None, - *, - entity_type: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> featurestore_online_service.ReadFeatureValuesResponse: + async def read_feature_values( + self, + request: featurestore_online_service.ReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore_online_service.ReadFeatureValuesResponse: r"""Reads Feature values of a specific entity of an EntityType. For reading feature values of multiple entities of an EntityType, please use @@ -197,8 +233,10 @@ async def read_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_online_service.ReadFeatureValuesRequest(request) @@ -219,30 +257,28 @@ async def read_feature_values(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type', request.entity_type), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def streaming_read_feature_values(self, - request: featurestore_online_service.StreamingReadFeatureValuesRequest = None, - *, - entity_type: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[featurestore_online_service.ReadFeatureValuesResponse]]: + def streaming_read_feature_values( + self, + request: featurestore_online_service.StreamingReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[ + AsyncIterable[featurestore_online_service.ReadFeatureValuesResponse] + ]: r"""Reads Feature values for multiple entities. Depending on their size, data for different entities may be broken up across multiple responses. @@ -280,8 +316,10 @@ def streaming_read_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_online_service.StreamingReadFeatureValuesRequest(request) @@ -302,38 +340,26 @@ def streaming_read_feature_values(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type', request.entity_type), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'FeaturestoreOnlineServingServiceAsyncClient', -) +__all__ = ("FeaturestoreOnlineServingServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py index 7a1b71a568..f54b7c109b 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py @@ -23,20 +23,25 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import featurestore_online_service -from .transports.base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import ( + FeaturestoreOnlineServingServiceTransport, + DEFAULT_CLIENT_INFO, +) from .transports.grpc import FeaturestoreOnlineServingServiceGrpcTransport -from .transports.grpc_asyncio import FeaturestoreOnlineServingServiceGrpcAsyncIOTransport +from .transports.grpc_asyncio import ( + FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, +) class FeaturestoreOnlineServingServiceClientMeta(type): @@ -46,13 +51,18 @@ class FeaturestoreOnlineServingServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]] - _transport_registry['grpc'] = FeaturestoreOnlineServingServiceGrpcTransport - _transport_registry['grpc_asyncio'] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[FeaturestoreOnlineServingServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]] + _transport_registry["grpc"] = FeaturestoreOnlineServingServiceGrpcTransport + _transport_registry[ + "grpc_asyncio" + ] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[FeaturestoreOnlineServingServiceTransport]: """Return an appropriate transport class. Args: @@ -71,7 +81,9 @@ def get_transport_class(cls, return next(iter(cls._transport_registry.values())) -class FeaturestoreOnlineServingServiceClient(metaclass=FeaturestoreOnlineServingServiceClientMeta): +class FeaturestoreOnlineServingServiceClient( + metaclass=FeaturestoreOnlineServingServiceClientMeta +): """A service for serving online feature values.""" @staticmethod @@ -103,7 +115,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -138,9 +150,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: FeaturestoreOnlineServingServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -155,77 +166,93 @@ def transport(self) -> FeaturestoreOnlineServingServiceTransport: return self._transport @staticmethod - def entity_type_path(project: str,location: str,featurestore: str,entity_type: str,) -> str: + def entity_type_path( + project: str, location: str, featurestore: str, entity_type: str, + ) -> str: """Return a fully-qualified entity_type string.""" - return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) + return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format( + project=project, + location=location, + featurestore=featurestore, + entity_type=entity_type, + ) @staticmethod - def parse_entity_type_path(path: str) -> Dict[str,str]: + def parse_entity_type_path(path: str) -> Dict[str, str]: """Parse a entity_type path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, FeaturestoreOnlineServingServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, FeaturestoreOnlineServingServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the featurestore online serving service client. Args: @@ -269,7 +296,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -279,7 +308,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -291,7 +322,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -303,8 +336,10 @@ def __init__(self, *, if isinstance(transport, FeaturestoreOnlineServingServiceTransport): # transport is a FeaturestoreOnlineServingServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -323,14 +358,15 @@ def __init__(self, *, client_info=client_info, ) - def read_feature_values(self, - request: featurestore_online_service.ReadFeatureValuesRequest = None, - *, - entity_type: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> featurestore_online_service.ReadFeatureValuesResponse: + def read_feature_values( + self, + request: featurestore_online_service.ReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore_online_service.ReadFeatureValuesResponse: r"""Reads Feature values of a specific entity of an EntityType. For reading feature values of multiple entities of an EntityType, please use @@ -369,14 +405,18 @@ def read_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_online_service.ReadFeatureValuesRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, featurestore_online_service.ReadFeatureValuesRequest): + if not isinstance( + request, featurestore_online_service.ReadFeatureValuesRequest + ): request = featurestore_online_service.ReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -392,30 +432,26 @@ def read_feature_values(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type', request.entity_type), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def streaming_read_feature_values(self, - request: featurestore_online_service.StreamingReadFeatureValuesRequest = None, - *, - entity_type: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> Iterable[featurestore_online_service.ReadFeatureValuesResponse]: + def streaming_read_feature_values( + self, + request: featurestore_online_service.StreamingReadFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[featurestore_online_service.ReadFeatureValuesResponse]: r"""Reads Feature values for multiple entities. Depending on their size, data for different entities may be broken up across multiple responses. @@ -453,15 +489,21 @@ def streaming_read_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_online_service.StreamingReadFeatureValuesRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, featurestore_online_service.StreamingReadFeatureValuesRequest): - request = featurestore_online_service.StreamingReadFeatureValuesRequest(request) + if not isinstance( + request, featurestore_online_service.StreamingReadFeatureValuesRequest + ): + request = featurestore_online_service.StreamingReadFeatureValuesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -471,43 +513,33 @@ def streaming_read_feature_values(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.streaming_read_feature_values] + rpc = self._transport._wrapped_methods[ + self._transport.streaming_read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type', request.entity_type), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'FeaturestoreOnlineServingServiceClient', -) +__all__ = ("FeaturestoreOnlineServingServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py index e3326680c7..fbb212cbc6 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/__init__.py @@ -24,12 +24,16 @@ # Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]] -_transport_registry['grpc'] = FeaturestoreOnlineServingServiceGrpcTransport -_transport_registry['grpc_asyncio'] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]] +_transport_registry["grpc"] = FeaturestoreOnlineServingServiceGrpcTransport +_transport_registry[ + "grpc_asyncio" +] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport __all__ = ( - 'FeaturestoreOnlineServingServiceTransport', - 'FeaturestoreOnlineServingServiceGrpcTransport', - 'FeaturestoreOnlineServingServiceGrpcAsyncIOTransport', + "FeaturestoreOnlineServingServiceTransport", + "FeaturestoreOnlineServingServiceGrpcTransport", + "FeaturestoreOnlineServingServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py index 8db9596f98..29ef041c28 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class FeaturestoreOnlineServingServiceTransport(abc.ABC): """Abstract transport class for FeaturestoreOnlineServingService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -76,8 +76,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -86,17 +86,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -105,37 +107,38 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.read_feature_values: gapic_v1.method.wrap_method( - self.read_feature_values, - default_timeout=None, - client_info=client_info, + self.read_feature_values, default_timeout=None, client_info=client_info, ), self.streaming_read_feature_values: gapic_v1.method.wrap_method( self.streaming_read_feature_values, default_timeout=None, client_info=client_info, ), - } @property - def read_feature_values(self) -> typing.Callable[ - [featurestore_online_service.ReadFeatureValuesRequest], - typing.Union[ - featurestore_online_service.ReadFeatureValuesResponse, - typing.Awaitable[featurestore_online_service.ReadFeatureValuesResponse] - ]]: + def read_feature_values( + self, + ) -> typing.Callable[ + [featurestore_online_service.ReadFeatureValuesRequest], + typing.Union[ + featurestore_online_service.ReadFeatureValuesResponse, + typing.Awaitable[featurestore_online_service.ReadFeatureValuesResponse], + ], + ]: raise NotImplementedError() @property - def streaming_read_feature_values(self) -> typing.Callable[ - [featurestore_online_service.StreamingReadFeatureValuesRequest], - typing.Union[ - featurestore_online_service.ReadFeatureValuesResponse, - typing.Awaitable[featurestore_online_service.ReadFeatureValuesResponse] - ]]: + def streaming_read_feature_values( + self, + ) -> typing.Callable[ + [featurestore_online_service.StreamingReadFeatureValuesRequest], + typing.Union[ + featurestore_online_service.ReadFeatureValuesResponse, + typing.Awaitable[featurestore_online_service.ReadFeatureValuesResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'FeaturestoreOnlineServingServiceTransport', -) +__all__ = ("FeaturestoreOnlineServingServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py index 6ba3a31748..97b31e4acc 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py @@ -18,10 +18,10 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -31,7 +31,9 @@ from .base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO -class FeaturestoreOnlineServingServiceGrpcTransport(FeaturestoreOnlineServingServiceTransport): +class FeaturestoreOnlineServingServiceGrpcTransport( + FeaturestoreOnlineServingServiceTransport +): """gRPC backend transport for FeaturestoreOnlineServingService. A service for serving online feature values. @@ -43,21 +45,24 @@ class FeaturestoreOnlineServingServiceGrpcTransport(FeaturestoreOnlineServingSer It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -168,13 +173,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -207,7 +214,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -217,9 +224,12 @@ def grpc_channel(self) -> grpc.Channel: return self._grpc_channel @property - def read_feature_values(self) -> Callable[ - [featurestore_online_service.ReadFeatureValuesRequest], - featurestore_online_service.ReadFeatureValuesResponse]: + def read_feature_values( + self, + ) -> Callable[ + [featurestore_online_service.ReadFeatureValuesRequest], + featurestore_online_service.ReadFeatureValuesResponse, + ]: r"""Return a callable for the read feature values method over gRPC. Reads Feature values of a specific entity of an @@ -237,18 +247,21 @@ def read_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'read_feature_values' not in self._stubs: - self._stubs['read_feature_values'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/ReadFeatureValues', + if "read_feature_values" not in self._stubs: + self._stubs["read_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/ReadFeatureValues", request_serializer=featurestore_online_service.ReadFeatureValuesRequest.serialize, response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, ) - return self._stubs['read_feature_values'] + return self._stubs["read_feature_values"] @property - def streaming_read_feature_values(self) -> Callable[ - [featurestore_online_service.StreamingReadFeatureValuesRequest], - featurestore_online_service.ReadFeatureValuesResponse]: + def streaming_read_feature_values( + self, + ) -> Callable[ + [featurestore_online_service.StreamingReadFeatureValuesRequest], + featurestore_online_service.ReadFeatureValuesResponse, + ]: r"""Return a callable for the streaming read feature values method over gRPC. Reads Feature values for multiple entities. Depending @@ -265,15 +278,15 @@ def streaming_read_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'streaming_read_feature_values' not in self._stubs: - self._stubs['streaming_read_feature_values'] = self.grpc_channel.unary_stream( - '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/StreamingReadFeatureValues', + if "streaming_read_feature_values" not in self._stubs: + self._stubs[ + "streaming_read_feature_values" + ] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/StreamingReadFeatureValues", request_serializer=featurestore_online_service.StreamingReadFeatureValuesRequest.serialize, response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, ) - return self._stubs['streaming_read_feature_values'] + return self._stubs["streaming_read_feature_values"] -__all__ = ( - 'FeaturestoreOnlineServingServiceGrpcTransport', -) +__all__ = ("FeaturestoreOnlineServingServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py index bd03ab6626..5f92a32ab6 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -18,13 +18,13 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import featurestore_online_service @@ -33,7 +33,9 @@ from .grpc import FeaturestoreOnlineServingServiceGrpcTransport -class FeaturestoreOnlineServingServiceGrpcAsyncIOTransport(FeaturestoreOnlineServingServiceTransport): +class FeaturestoreOnlineServingServiceGrpcAsyncIOTransport( + FeaturestoreOnlineServingServiceTransport +): """gRPC AsyncIO backend transport for FeaturestoreOnlineServingService. A service for serving online feature values. @@ -50,13 +52,15 @@ class FeaturestoreOnlineServingServiceGrpcAsyncIOTransport(FeaturestoreOnlineSer _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -85,22 +89,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -222,9 +228,12 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def read_feature_values(self) -> Callable[ - [featurestore_online_service.ReadFeatureValuesRequest], - Awaitable[featurestore_online_service.ReadFeatureValuesResponse]]: + def read_feature_values( + self, + ) -> Callable[ + [featurestore_online_service.ReadFeatureValuesRequest], + Awaitable[featurestore_online_service.ReadFeatureValuesResponse], + ]: r"""Return a callable for the read feature values method over gRPC. Reads Feature values of a specific entity of an @@ -242,18 +251,21 @@ def read_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'read_feature_values' not in self._stubs: - self._stubs['read_feature_values'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/ReadFeatureValues', + if "read_feature_values" not in self._stubs: + self._stubs["read_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/ReadFeatureValues", request_serializer=featurestore_online_service.ReadFeatureValuesRequest.serialize, response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, ) - return self._stubs['read_feature_values'] + return self._stubs["read_feature_values"] @property - def streaming_read_feature_values(self) -> Callable[ - [featurestore_online_service.StreamingReadFeatureValuesRequest], - Awaitable[featurestore_online_service.ReadFeatureValuesResponse]]: + def streaming_read_feature_values( + self, + ) -> Callable[ + [featurestore_online_service.StreamingReadFeatureValuesRequest], + Awaitable[featurestore_online_service.ReadFeatureValuesResponse], + ]: r"""Return a callable for the streaming read feature values method over gRPC. Reads Feature values for multiple entities. Depending @@ -270,15 +282,15 @@ def streaming_read_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'streaming_read_feature_values' not in self._stubs: - self._stubs['streaming_read_feature_values'] = self.grpc_channel.unary_stream( - '/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/StreamingReadFeatureValues', + if "streaming_read_feature_values" not in self._stubs: + self._stubs[ + "streaming_read_feature_values" + ] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1beta1.FeaturestoreOnlineServingService/StreamingReadFeatureValues", request_serializer=featurestore_online_service.StreamingReadFeatureValuesRequest.serialize, response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize, ) - return self._stubs['streaming_read_feature_values'] + return self._stubs["streaming_read_feature_values"] -__all__ = ( - 'FeaturestoreOnlineServingServiceGrpcAsyncIOTransport', -) +__all__ = ("FeaturestoreOnlineServingServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py index e3d630a7cc..86c61ed8cf 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import FeaturestoreServiceAsyncClient __all__ = ( - 'FeaturestoreServiceClient', - 'FeaturestoreServiceAsyncClient', + "FeaturestoreServiceClient", + "FeaturestoreServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index 99fdb689e3..f58bc25d1b 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,26 +61,44 @@ class FeaturestoreServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = FeaturestoreServiceClient.DEFAULT_MTLS_ENDPOINT entity_type_path = staticmethod(FeaturestoreServiceClient.entity_type_path) - parse_entity_type_path = staticmethod(FeaturestoreServiceClient.parse_entity_type_path) + parse_entity_type_path = staticmethod( + FeaturestoreServiceClient.parse_entity_type_path + ) feature_path = staticmethod(FeaturestoreServiceClient.feature_path) parse_feature_path = staticmethod(FeaturestoreServiceClient.parse_feature_path) featurestore_path = staticmethod(FeaturestoreServiceClient.featurestore_path) - parse_featurestore_path = staticmethod(FeaturestoreServiceClient.parse_featurestore_path) + parse_featurestore_path = staticmethod( + FeaturestoreServiceClient.parse_featurestore_path + ) - common_billing_account_path = staticmethod(FeaturestoreServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(FeaturestoreServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + FeaturestoreServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + FeaturestoreServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(FeaturestoreServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(FeaturestoreServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + FeaturestoreServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(FeaturestoreServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(FeaturestoreServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + FeaturestoreServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + FeaturestoreServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(FeaturestoreServiceClient.common_project_path) - parse_common_project_path = staticmethod(FeaturestoreServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + FeaturestoreServiceClient.parse_common_project_path + ) common_location_path = staticmethod(FeaturestoreServiceClient.common_location_path) - parse_common_location_path = staticmethod(FeaturestoreServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + FeaturestoreServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -123,14 +141,19 @@ def transport(self) -> FeaturestoreServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(FeaturestoreServiceClient).get_transport_class, type(FeaturestoreServiceClient)) + get_transport_class = functools.partial( + type(FeaturestoreServiceClient).get_transport_class, + type(FeaturestoreServiceClient), + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, FeaturestoreServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, FeaturestoreServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the featurestore service client. Args: @@ -169,18 +192,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_featurestore(self, - request: featurestore_service.CreateFeaturestoreRequest = None, - *, - parent: str = None, - featurestore: gca_featurestore.Featurestore = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_featurestore( + self, + request: featurestore_service.CreateFeaturestoreRequest = None, + *, + parent: str = None, + featurestore: gca_featurestore.Featurestore = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a new Featurestore in a given project and location. @@ -223,8 +246,10 @@ async def create_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.CreateFeaturestoreRequest(request) @@ -247,18 +272,11 @@ async def create_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -271,14 +289,15 @@ async def create_featurestore(self, # Done; return the response. return response - async def get_featurestore(self, - request: featurestore_service.GetFeaturestoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> featurestore.Featurestore: + async def get_featurestore( + self, + request: featurestore_service.GetFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore.Featurestore: r"""Gets details of a single Featurestore. Args: @@ -311,8 +330,10 @@ async def get_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.GetFeaturestoreRequest(request) @@ -333,30 +354,24 @@ async def get_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_featurestores(self, - request: featurestore_service.ListFeaturestoresRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListFeaturestoresAsyncPager: + async def list_featurestores( + self, + request: featurestore_service.ListFeaturestoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturestoresAsyncPager: r"""Lists Featurestores in a given project and location. Args: @@ -392,8 +407,10 @@ async def list_featurestores(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.ListFeaturestoresRequest(request) @@ -414,40 +431,31 @@ async def list_featurestores(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListFeaturestoresAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_featurestore(self, - request: featurestore_service.UpdateFeaturestoreRequest = None, - *, - featurestore: gca_featurestore.Featurestore = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_featurestore( + self, + request: featurestore_service.UpdateFeaturestoreRequest = None, + *, + featurestore: gca_featurestore.Featurestore = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates the parameters of a single Featurestore. Args: @@ -504,8 +512,10 @@ async def update_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.UpdateFeaturestoreRequest(request) @@ -528,18 +538,13 @@ async def update_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('featurestore.name', request.featurestore.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("featurestore.name", request.featurestore.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -552,14 +557,15 @@ async def update_featurestore(self, # Done; return the response. return response - async def delete_featurestore(self, - request: featurestore_service.DeleteFeaturestoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_featurestore( + self, + request: featurestore_service.DeleteFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a single Featurestore. The Featurestore must not contain any EntityTypes or ``force`` must be set to true for the request to succeed. @@ -607,8 +613,10 @@ async def delete_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.DeleteFeaturestoreRequest(request) @@ -629,18 +637,11 @@ async def delete_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -653,15 +654,16 @@ async def delete_featurestore(self, # Done; return the response. return response - async def create_entity_type(self, - request: featurestore_service.CreateEntityTypeRequest = None, - *, - parent: str = None, - entity_type: gca_entity_type.EntityType = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_entity_type( + self, + request: featurestore_service.CreateEntityTypeRequest = None, + *, + parent: str = None, + entity_type: gca_entity_type.EntityType = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a new EntityType in a given Featurestore. Args: @@ -703,8 +705,10 @@ async def create_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.CreateEntityTypeRequest(request) @@ -727,18 +731,11 @@ async def create_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -751,14 +748,15 @@ async def create_entity_type(self, # Done; return the response. return response - async def get_entity_type(self, - request: featurestore_service.GetEntityTypeRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> entity_type.EntityType: + async def get_entity_type( + self, + request: featurestore_service.GetEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> entity_type.EntityType: r"""Gets details of a single EntityType. Args: @@ -794,8 +792,10 @@ async def get_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.GetEntityTypeRequest(request) @@ -816,30 +816,24 @@ async def get_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_entity_types(self, - request: featurestore_service.ListEntityTypesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEntityTypesAsyncPager: + async def list_entity_types( + self, + request: featurestore_service.ListEntityTypesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEntityTypesAsyncPager: r"""Lists EntityTypes in a given Featurestore. Args: @@ -875,8 +869,10 @@ async def list_entity_types(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.ListEntityTypesRequest(request) @@ -897,40 +893,31 @@ async def list_entity_types(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListEntityTypesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_entity_type(self, - request: featurestore_service.UpdateEntityTypeRequest = None, - *, - entity_type: gca_entity_type.EntityType = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_entity_type.EntityType: + async def update_entity_type( + self, + request: featurestore_service.UpdateEntityTypeRequest = None, + *, + entity_type: gca_entity_type.EntityType = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_entity_type.EntityType: r"""Updates the parameters of a single EntityType. Args: @@ -987,8 +974,10 @@ async def update_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.UpdateEntityTypeRequest(request) @@ -1011,30 +1000,26 @@ async def update_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type.name', request.entity_type.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type.name", request.entity_type.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_entity_type(self, - request: featurestore_service.DeleteEntityTypeRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_entity_type( + self, + request: featurestore_service.DeleteEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a single EntityType. The EntityType must not have any Features or ``force`` must be set to true for the request to succeed. @@ -1082,8 +1067,10 @@ async def delete_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.DeleteEntityTypeRequest(request) @@ -1104,18 +1091,11 @@ async def delete_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1128,15 +1108,16 @@ async def delete_entity_type(self, # Done; return the response. return response - async def create_feature(self, - request: featurestore_service.CreateFeatureRequest = None, - *, - parent: str = None, - feature: gca_feature.Feature = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_feature( + self, + request: featurestore_service.CreateFeatureRequest = None, + *, + parent: str = None, + feature: gca_feature.Feature = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a new Feature in a given EntityType. Args: @@ -1177,8 +1158,10 @@ async def create_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.CreateFeatureRequest(request) @@ -1201,18 +1184,11 @@ async def create_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1225,15 +1201,16 @@ async def create_feature(self, # Done; return the response. return response - async def batch_create_features(self, - request: featurestore_service.BatchCreateFeaturesRequest = None, - *, - parent: str = None, - requests: Sequence[featurestore_service.CreateFeatureRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_create_features( + self, + request: featurestore_service.BatchCreateFeaturesRequest = None, + *, + parent: str = None, + requests: Sequence[featurestore_service.CreateFeatureRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a batch of Features in a given EntityType. Args: @@ -1281,8 +1258,10 @@ async def batch_create_features(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.BatchCreateFeaturesRequest(request) @@ -1306,18 +1285,11 @@ async def batch_create_features(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1330,14 +1302,15 @@ async def batch_create_features(self, # Done; return the response. return response - async def get_feature(self, - request: featurestore_service.GetFeatureRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> feature.Feature: + async def get_feature( + self, + request: featurestore_service.GetFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> feature.Feature: r"""Gets details of a single Feature. Args: @@ -1372,8 +1345,10 @@ async def get_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.GetFeatureRequest(request) @@ -1394,30 +1369,24 @@ async def get_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_features(self, - request: featurestore_service.ListFeaturesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListFeaturesAsyncPager: + async def list_features( + self, + request: featurestore_service.ListFeaturesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturesAsyncPager: r"""Lists Features in a given EntityType. Args: @@ -1453,8 +1422,10 @@ async def list_features(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.ListFeaturesRequest(request) @@ -1475,40 +1446,31 @@ async def list_features(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListFeaturesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_feature(self, - request: featurestore_service.UpdateFeatureRequest = None, - *, - feature: gca_feature.Feature = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_feature.Feature: + async def update_feature( + self, + request: featurestore_service.UpdateFeatureRequest = None, + *, + feature: gca_feature.Feature = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_feature.Feature: r"""Updates the parameters of a single Feature. Args: @@ -1564,8 +1526,10 @@ async def update_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.UpdateFeatureRequest(request) @@ -1588,30 +1552,26 @@ async def update_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('feature.name', request.feature.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("feature.name", request.feature.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_feature(self, - request: featurestore_service.DeleteFeatureRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_feature( + self, + request: featurestore_service.DeleteFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a single Feature. Args: @@ -1657,8 +1617,10 @@ async def delete_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.DeleteFeatureRequest(request) @@ -1679,18 +1641,11 @@ async def delete_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1703,14 +1658,15 @@ async def delete_feature(self, # Done; return the response. return response - async def import_feature_values(self, - request: featurestore_service.ImportFeatureValuesRequest = None, - *, - entity_type: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def import_feature_values( + self, + request: featurestore_service.ImportFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Imports Feature values into the Featurestore from a source storage. The progress of the import is tracked by the returned @@ -1768,8 +1724,10 @@ async def import_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.ImportFeatureValuesRequest(request) @@ -1790,18 +1748,13 @@ async def import_feature_values(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type', request.entity_type), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1814,14 +1767,15 @@ async def import_feature_values(self, # Done; return the response. return response - async def batch_read_feature_values(self, - request: featurestore_service.BatchReadFeatureValuesRequest = None, - *, - featurestore: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_read_feature_values( + self, + request: featurestore_service.BatchReadFeatureValuesRequest = None, + *, + featurestore: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Batch reads Feature values from a Featurestore. This API enables batch reading Feature values, where each read instance in the batch may read Feature values @@ -1863,8 +1817,10 @@ async def batch_read_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.BatchReadFeatureValuesRequest(request) @@ -1885,18 +1841,13 @@ async def batch_read_feature_values(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('featurestore', request.featurestore), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("featurestore", request.featurestore),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -1909,14 +1860,15 @@ async def batch_read_feature_values(self, # Done; return the response. return response - async def search_features(self, - request: featurestore_service.SearchFeaturesRequest = None, - *, - location: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchFeaturesAsyncPager: + async def search_features( + self, + request: featurestore_service.SearchFeaturesRequest = None, + *, + location: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchFeaturesAsyncPager: r"""Searches Features matching a query in a given project. @@ -1953,8 +1905,10 @@ async def search_features(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([location]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = featurestore_service.SearchFeaturesRequest(request) @@ -1975,47 +1929,30 @@ async def search_features(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('location', request.location), - )), + gapic_v1.routing_header.to_grpc_metadata((("location", request.location),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchFeaturesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'FeaturestoreServiceAsyncClient', -) +__all__ = ("FeaturestoreServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index 1bef3bb531..2b9991c9ba 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,13 +61,16 @@ class FeaturestoreServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreServiceTransport]] - _transport_registry['grpc'] = FeaturestoreServiceGrpcTransport - _transport_registry['grpc_asyncio'] = FeaturestoreServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[FeaturestoreServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[FeaturestoreServiceTransport]] + _transport_registry["grpc"] = FeaturestoreServiceGrpcTransport + _transport_registry["grpc_asyncio"] = FeaturestoreServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[FeaturestoreServiceTransport]: """Return an appropriate transport class. Args: @@ -120,7 +123,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -155,9 +158,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: FeaturestoreServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -172,99 +174,131 @@ def transport(self) -> FeaturestoreServiceTransport: return self._transport @staticmethod - def entity_type_path(project: str,location: str,featurestore: str,entity_type: str,) -> str: + def entity_type_path( + project: str, location: str, featurestore: str, entity_type: str, + ) -> str: """Return a fully-qualified entity_type string.""" - return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) + return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format( + project=project, + location=location, + featurestore=featurestore, + entity_type=entity_type, + ) @staticmethod - def parse_entity_type_path(path: str) -> Dict[str,str]: + def parse_entity_type_path(path: str) -> Dict[str, str]: """Parse a entity_type path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def feature_path(project: str,location: str,featurestore: str,entity_type: str,feature: str,) -> str: + def feature_path( + project: str, location: str, featurestore: str, entity_type: str, feature: str, + ) -> str: """Return a fully-qualified feature string.""" - return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, feature=feature, ) + return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format( + project=project, + location=location, + featurestore=featurestore, + entity_type=entity_type, + feature=feature, + ) @staticmethod - def parse_feature_path(path: str) -> Dict[str,str]: + def parse_feature_path(path: str) -> Dict[str, str]: """Parse a feature path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)/features/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)/features/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def featurestore_path(project: str,location: str,featurestore: str,) -> str: + def featurestore_path(project: str, location: str, featurestore: str,) -> str: """Return a fully-qualified featurestore string.""" - return "projects/{project}/locations/{location}/featurestores/{featurestore}".format(project=project, location=location, featurestore=featurestore, ) + return "projects/{project}/locations/{location}/featurestores/{featurestore}".format( + project=project, location=location, featurestore=featurestore, + ) @staticmethod - def parse_featurestore_path(path: str) -> Dict[str,str]: + def parse_featurestore_path(path: str) -> Dict[str, str]: """Parse a featurestore path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, FeaturestoreServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, FeaturestoreServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the featurestore service client. Args: @@ -308,7 +342,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -318,7 +354,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -330,7 +368,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -342,8 +382,10 @@ def __init__(self, *, if isinstance(transport, FeaturestoreServiceTransport): # transport is a FeaturestoreServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -362,15 +404,16 @@ def __init__(self, *, client_info=client_info, ) - def create_featurestore(self, - request: featurestore_service.CreateFeaturestoreRequest = None, - *, - parent: str = None, - featurestore: gca_featurestore.Featurestore = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_featurestore( + self, + request: featurestore_service.CreateFeaturestoreRequest = None, + *, + parent: str = None, + featurestore: gca_featurestore.Featurestore = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a new Featurestore in a given project and location. @@ -413,8 +456,10 @@ def create_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.CreateFeaturestoreRequest. @@ -438,18 +483,11 @@ def create_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -462,14 +500,15 @@ def create_featurestore(self, # Done; return the response. return response - def get_featurestore(self, - request: featurestore_service.GetFeaturestoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> featurestore.Featurestore: + def get_featurestore( + self, + request: featurestore_service.GetFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> featurestore.Featurestore: r"""Gets details of a single Featurestore. Args: @@ -502,8 +541,10 @@ def get_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.GetFeaturestoreRequest. @@ -525,30 +566,24 @@ def get_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_featurestores(self, - request: featurestore_service.ListFeaturestoresRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListFeaturestoresPager: + def list_featurestores( + self, + request: featurestore_service.ListFeaturestoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturestoresPager: r"""Lists Featurestores in a given project and location. Args: @@ -584,8 +619,10 @@ def list_featurestores(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.ListFeaturestoresRequest. @@ -607,40 +644,31 @@ def list_featurestores(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListFeaturestoresPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_featurestore(self, - request: featurestore_service.UpdateFeaturestoreRequest = None, - *, - featurestore: gca_featurestore.Featurestore = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def update_featurestore( + self, + request: featurestore_service.UpdateFeaturestoreRequest = None, + *, + featurestore: gca_featurestore.Featurestore = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Updates the parameters of a single Featurestore. Args: @@ -697,8 +725,10 @@ def update_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.UpdateFeaturestoreRequest. @@ -722,18 +752,13 @@ def update_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('featurestore.name', request.featurestore.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("featurestore.name", request.featurestore.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -746,14 +771,15 @@ def update_featurestore(self, # Done; return the response. return response - def delete_featurestore(self, - request: featurestore_service.DeleteFeaturestoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_featurestore( + self, + request: featurestore_service.DeleteFeaturestoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a single Featurestore. The Featurestore must not contain any EntityTypes or ``force`` must be set to true for the request to succeed. @@ -801,8 +827,10 @@ def delete_featurestore(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.DeleteFeaturestoreRequest. @@ -824,18 +852,11 @@ def delete_featurestore(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -848,15 +869,16 @@ def delete_featurestore(self, # Done; return the response. return response - def create_entity_type(self, - request: featurestore_service.CreateEntityTypeRequest = None, - *, - parent: str = None, - entity_type: gca_entity_type.EntityType = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_entity_type( + self, + request: featurestore_service.CreateEntityTypeRequest = None, + *, + parent: str = None, + entity_type: gca_entity_type.EntityType = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a new EntityType in a given Featurestore. Args: @@ -898,8 +920,10 @@ def create_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.CreateEntityTypeRequest. @@ -923,18 +947,11 @@ def create_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -947,14 +964,15 @@ def create_entity_type(self, # Done; return the response. return response - def get_entity_type(self, - request: featurestore_service.GetEntityTypeRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> entity_type.EntityType: + def get_entity_type( + self, + request: featurestore_service.GetEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> entity_type.EntityType: r"""Gets details of a single EntityType. Args: @@ -990,8 +1008,10 @@ def get_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.GetEntityTypeRequest. @@ -1013,30 +1033,24 @@ def get_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_entity_types(self, - request: featurestore_service.ListEntityTypesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEntityTypesPager: + def list_entity_types( + self, + request: featurestore_service.ListEntityTypesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEntityTypesPager: r"""Lists EntityTypes in a given Featurestore. Args: @@ -1072,8 +1086,10 @@ def list_entity_types(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.ListEntityTypesRequest. @@ -1095,40 +1111,31 @@ def list_entity_types(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEntityTypesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_entity_type(self, - request: featurestore_service.UpdateEntityTypeRequest = None, - *, - entity_type: gca_entity_type.EntityType = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_entity_type.EntityType: + def update_entity_type( + self, + request: featurestore_service.UpdateEntityTypeRequest = None, + *, + entity_type: gca_entity_type.EntityType = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_entity_type.EntityType: r"""Updates the parameters of a single EntityType. Args: @@ -1185,8 +1192,10 @@ def update_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.UpdateEntityTypeRequest. @@ -1210,30 +1219,26 @@ def update_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type.name', request.entity_type.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type.name", request.entity_type.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_entity_type(self, - request: featurestore_service.DeleteEntityTypeRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_entity_type( + self, + request: featurestore_service.DeleteEntityTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a single EntityType. The EntityType must not have any Features or ``force`` must be set to true for the request to succeed. @@ -1281,8 +1286,10 @@ def delete_entity_type(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.DeleteEntityTypeRequest. @@ -1304,18 +1311,11 @@ def delete_entity_type(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1328,15 +1328,16 @@ def delete_entity_type(self, # Done; return the response. return response - def create_feature(self, - request: featurestore_service.CreateFeatureRequest = None, - *, - parent: str = None, - feature: gca_feature.Feature = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_feature( + self, + request: featurestore_service.CreateFeatureRequest = None, + *, + parent: str = None, + feature: gca_feature.Feature = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a new Feature in a given EntityType. Args: @@ -1377,8 +1378,10 @@ def create_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.CreateFeatureRequest. @@ -1402,18 +1405,11 @@ def create_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1426,15 +1422,16 @@ def create_feature(self, # Done; return the response. return response - def batch_create_features(self, - request: featurestore_service.BatchCreateFeaturesRequest = None, - *, - parent: str = None, - requests: Sequence[featurestore_service.CreateFeatureRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def batch_create_features( + self, + request: featurestore_service.BatchCreateFeaturesRequest = None, + *, + parent: str = None, + requests: Sequence[featurestore_service.CreateFeatureRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a batch of Features in a given EntityType. Args: @@ -1482,8 +1479,10 @@ def batch_create_features(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.BatchCreateFeaturesRequest. @@ -1507,18 +1506,11 @@ def batch_create_features(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1531,14 +1523,15 @@ def batch_create_features(self, # Done; return the response. return response - def get_feature(self, - request: featurestore_service.GetFeatureRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> feature.Feature: + def get_feature( + self, + request: featurestore_service.GetFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> feature.Feature: r"""Gets details of a single Feature. Args: @@ -1573,8 +1566,10 @@ def get_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.GetFeatureRequest. @@ -1596,30 +1591,24 @@ def get_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_features(self, - request: featurestore_service.ListFeaturesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListFeaturesPager: + def list_features( + self, + request: featurestore_service.ListFeaturesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListFeaturesPager: r"""Lists Features in a given EntityType. Args: @@ -1655,8 +1644,10 @@ def list_features(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.ListFeaturesRequest. @@ -1678,40 +1669,31 @@ def list_features(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListFeaturesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_feature(self, - request: featurestore_service.UpdateFeatureRequest = None, - *, - feature: gca_feature.Feature = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_feature.Feature: + def update_feature( + self, + request: featurestore_service.UpdateFeatureRequest = None, + *, + feature: gca_feature.Feature = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_feature.Feature: r"""Updates the parameters of a single Feature. Args: @@ -1767,8 +1749,10 @@ def update_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.UpdateFeatureRequest. @@ -1792,30 +1776,26 @@ def update_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('feature.name', request.feature.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("feature.name", request.feature.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_feature(self, - request: featurestore_service.DeleteFeatureRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_feature( + self, + request: featurestore_service.DeleteFeatureRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a single Feature. Args: @@ -1861,8 +1841,10 @@ def delete_feature(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.DeleteFeatureRequest. @@ -1884,18 +1866,11 @@ def delete_feature(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1908,14 +1883,15 @@ def delete_feature(self, # Done; return the response. return response - def import_feature_values(self, - request: featurestore_service.ImportFeatureValuesRequest = None, - *, - entity_type: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def import_feature_values( + self, + request: featurestore_service.ImportFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Imports Feature values into the Featurestore from a source storage. The progress of the import is tracked by the returned @@ -1973,8 +1949,10 @@ def import_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.ImportFeatureValuesRequest. @@ -1996,18 +1974,13 @@ def import_feature_values(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('entity_type', request.entity_type), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -2020,14 +1993,15 @@ def import_feature_values(self, # Done; return the response. return response - def batch_read_feature_values(self, - request: featurestore_service.BatchReadFeatureValuesRequest = None, - *, - featurestore: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def batch_read_feature_values( + self, + request: featurestore_service.BatchReadFeatureValuesRequest = None, + *, + featurestore: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Batch reads Feature values from a Featurestore. This API enables batch reading Feature values, where each read instance in the batch may read Feature values @@ -2069,8 +2043,10 @@ def batch_read_feature_values(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.BatchReadFeatureValuesRequest. @@ -2087,23 +2063,20 @@ def batch_read_feature_values(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.batch_read_feature_values] + rpc = self._transport._wrapped_methods[ + self._transport.batch_read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('featurestore', request.featurestore), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("featurestore", request.featurestore),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -2116,14 +2089,15 @@ def batch_read_feature_values(self, # Done; return the response. return response - def search_features(self, - request: featurestore_service.SearchFeaturesRequest = None, - *, - location: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchFeaturesPager: + def search_features( + self, + request: featurestore_service.SearchFeaturesRequest = None, + *, + location: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchFeaturesPager: r"""Searches Features matching a query in a given project. @@ -2160,8 +2134,10 @@ def search_features(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([location]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a featurestore_service.SearchFeaturesRequest. @@ -2183,47 +2159,30 @@ def search_features(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('location', request.location), - )), + gapic_v1.routing_header.to_grpc_metadata((("location", request.location),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchFeaturesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'FeaturestoreServiceClient', -) +__all__ = ("FeaturestoreServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py index 7baa8e920c..98e6d56e17 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import entity_type from google.cloud.aiplatform_v1beta1.types import feature @@ -40,12 +49,15 @@ class ListFeaturestoresPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., featurestore_service.ListFeaturestoresResponse], - request: featurestore_service.ListFeaturestoresRequest, - response: featurestore_service.ListFeaturestoresResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., featurestore_service.ListFeaturestoresResponse], + request: featurestore_service.ListFeaturestoresRequest, + response: featurestore_service.ListFeaturestoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -79,7 +91,7 @@ def __iter__(self) -> Iterable[featurestore.Featurestore]: yield from page.featurestores def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListFeaturestoresAsyncPager: @@ -99,12 +111,17 @@ class ListFeaturestoresAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[featurestore_service.ListFeaturestoresResponse]], - request: featurestore_service.ListFeaturestoresRequest, - response: featurestore_service.ListFeaturestoresResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[featurestore_service.ListFeaturestoresResponse] + ], + request: featurestore_service.ListFeaturestoresRequest, + response: featurestore_service.ListFeaturestoresResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -126,7 +143,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[featurestore_service.ListFeaturestoresResponse]: + async def pages( + self, + ) -> AsyncIterable[featurestore_service.ListFeaturestoresResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -142,7 +161,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListEntityTypesPager: @@ -162,12 +181,15 @@ class ListEntityTypesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., featurestore_service.ListEntityTypesResponse], - request: featurestore_service.ListEntityTypesRequest, - response: featurestore_service.ListEntityTypesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., featurestore_service.ListEntityTypesResponse], + request: featurestore_service.ListEntityTypesRequest, + response: featurestore_service.ListEntityTypesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +223,7 @@ def __iter__(self) -> Iterable[entity_type.EntityType]: yield from page.entity_types def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListEntityTypesAsyncPager: @@ -221,12 +243,15 @@ class ListEntityTypesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[featurestore_service.ListEntityTypesResponse]], - request: featurestore_service.ListEntityTypesRequest, - response: featurestore_service.ListEntityTypesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[featurestore_service.ListEntityTypesResponse]], + request: featurestore_service.ListEntityTypesRequest, + response: featurestore_service.ListEntityTypesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -248,7 +273,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[featurestore_service.ListEntityTypesResponse]: + async def pages( + self, + ) -> AsyncIterable[featurestore_service.ListEntityTypesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -264,7 +291,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListFeaturesPager: @@ -284,12 +311,15 @@ class ListFeaturesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., featurestore_service.ListFeaturesResponse], - request: featurestore_service.ListFeaturesRequest, - response: featurestore_service.ListFeaturesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., featurestore_service.ListFeaturesResponse], + request: featurestore_service.ListFeaturesRequest, + response: featurestore_service.ListFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +353,7 @@ def __iter__(self) -> Iterable[feature.Feature]: yield from page.features def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListFeaturesAsyncPager: @@ -343,12 +373,15 @@ class ListFeaturesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[featurestore_service.ListFeaturesResponse]], - request: featurestore_service.ListFeaturesRequest, - response: featurestore_service.ListFeaturesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[featurestore_service.ListFeaturesResponse]], + request: featurestore_service.ListFeaturesRequest, + response: featurestore_service.ListFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -386,7 +419,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchFeaturesPager: @@ -406,12 +439,15 @@ class SearchFeaturesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., featurestore_service.SearchFeaturesResponse], - request: featurestore_service.SearchFeaturesRequest, - response: featurestore_service.SearchFeaturesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., featurestore_service.SearchFeaturesResponse], + request: featurestore_service.SearchFeaturesRequest, + response: featurestore_service.SearchFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -445,7 +481,7 @@ def __iter__(self) -> Iterable[feature.Feature]: yield from page.features def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchFeaturesAsyncPager: @@ -465,12 +501,15 @@ class SearchFeaturesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[featurestore_service.SearchFeaturesResponse]], - request: featurestore_service.SearchFeaturesRequest, - response: featurestore_service.SearchFeaturesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[featurestore_service.SearchFeaturesResponse]], + request: featurestore_service.SearchFeaturesRequest, + response: featurestore_service.SearchFeaturesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -508,4 +547,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py index 3fdc8aa3df..8f1772f264 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/__init__.py @@ -24,12 +24,14 @@ # Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[FeaturestoreServiceTransport]] -_transport_registry['grpc'] = FeaturestoreServiceGrpcTransport -_transport_registry['grpc_asyncio'] = FeaturestoreServiceGrpcAsyncIOTransport +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[FeaturestoreServiceTransport]] +_transport_registry["grpc"] = FeaturestoreServiceGrpcTransport +_transport_registry["grpc_asyncio"] = FeaturestoreServiceGrpcAsyncIOTransport __all__ = ( - 'FeaturestoreServiceTransport', - 'FeaturestoreServiceGrpcTransport', - 'FeaturestoreServiceGrpcAsyncIOTransport', + "FeaturestoreServiceTransport", + "FeaturestoreServiceGrpcTransport", + "FeaturestoreServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py index 4adf29e11b..2f633c4f81 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -38,29 +38,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class FeaturestoreServiceTransport(abc.ABC): """Abstract transport class for FeaturestoreService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -83,8 +83,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -93,17 +93,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -112,59 +114,37 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_featurestore: gapic_v1.method.wrap_method( - self.create_featurestore, - default_timeout=None, - client_info=client_info, + self.create_featurestore, default_timeout=None, client_info=client_info, ), self.get_featurestore: gapic_v1.method.wrap_method( - self.get_featurestore, - default_timeout=None, - client_info=client_info, + self.get_featurestore, default_timeout=None, client_info=client_info, ), self.list_featurestores: gapic_v1.method.wrap_method( - self.list_featurestores, - default_timeout=None, - client_info=client_info, + self.list_featurestores, default_timeout=None, client_info=client_info, ), self.update_featurestore: gapic_v1.method.wrap_method( - self.update_featurestore, - default_timeout=None, - client_info=client_info, + self.update_featurestore, default_timeout=None, client_info=client_info, ), self.delete_featurestore: gapic_v1.method.wrap_method( - self.delete_featurestore, - default_timeout=None, - client_info=client_info, + self.delete_featurestore, default_timeout=None, client_info=client_info, ), self.create_entity_type: gapic_v1.method.wrap_method( - self.create_entity_type, - default_timeout=None, - client_info=client_info, + self.create_entity_type, default_timeout=None, client_info=client_info, ), self.get_entity_type: gapic_v1.method.wrap_method( - self.get_entity_type, - default_timeout=None, - client_info=client_info, + self.get_entity_type, default_timeout=None, client_info=client_info, ), self.list_entity_types: gapic_v1.method.wrap_method( - self.list_entity_types, - default_timeout=None, - client_info=client_info, + self.list_entity_types, default_timeout=None, client_info=client_info, ), self.update_entity_type: gapic_v1.method.wrap_method( - self.update_entity_type, - default_timeout=None, - client_info=client_info, + self.update_entity_type, default_timeout=None, client_info=client_info, ), self.delete_entity_type: gapic_v1.method.wrap_method( - self.delete_entity_type, - default_timeout=None, - client_info=client_info, + self.delete_entity_type, default_timeout=None, client_info=client_info, ), self.create_feature: gapic_v1.method.wrap_method( - self.create_feature, - default_timeout=None, - client_info=client_info, + self.create_feature, default_timeout=None, client_info=client_info, ), self.batch_create_features: gapic_v1.method.wrap_method( self.batch_create_features, @@ -172,24 +152,16 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_feature: gapic_v1.method.wrap_method( - self.get_feature, - default_timeout=None, - client_info=client_info, + self.get_feature, default_timeout=None, client_info=client_info, ), self.list_features: gapic_v1.method.wrap_method( - self.list_features, - default_timeout=None, - client_info=client_info, + self.list_features, default_timeout=None, client_info=client_info, ), self.update_feature: gapic_v1.method.wrap_method( - self.update_feature, - default_timeout=None, - client_info=client_info, + self.update_feature, default_timeout=None, client_info=client_info, ), self.delete_feature: gapic_v1.method.wrap_method( - self.delete_feature, - default_timeout=None, - client_info=client_info, + self.delete_feature, default_timeout=None, client_info=client_info, ), self.import_feature_values: gapic_v1.method.wrap_method( self.import_feature_values, @@ -202,11 +174,8 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.search_features: gapic_v1.method.wrap_method( - self.search_features, - default_timeout=None, - client_info=client_info, + self.search_features, default_timeout=None, client_info=client_info, ), - } @property @@ -215,177 +184,191 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_featurestore(self) -> typing.Callable[ - [featurestore_service.CreateFeaturestoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_featurestore( + self, + ) -> typing.Callable[ + [featurestore_service.CreateFeaturestoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_featurestore(self) -> typing.Callable[ - [featurestore_service.GetFeaturestoreRequest], - typing.Union[ - featurestore.Featurestore, - typing.Awaitable[featurestore.Featurestore] - ]]: + def get_featurestore( + self, + ) -> typing.Callable[ + [featurestore_service.GetFeaturestoreRequest], + typing.Union[ + featurestore.Featurestore, typing.Awaitable[featurestore.Featurestore] + ], + ]: raise NotImplementedError() @property - def list_featurestores(self) -> typing.Callable[ - [featurestore_service.ListFeaturestoresRequest], - typing.Union[ - featurestore_service.ListFeaturestoresResponse, - typing.Awaitable[featurestore_service.ListFeaturestoresResponse] - ]]: + def list_featurestores( + self, + ) -> typing.Callable[ + [featurestore_service.ListFeaturestoresRequest], + typing.Union[ + featurestore_service.ListFeaturestoresResponse, + typing.Awaitable[featurestore_service.ListFeaturestoresResponse], + ], + ]: raise NotImplementedError() @property - def update_featurestore(self) -> typing.Callable[ - [featurestore_service.UpdateFeaturestoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_featurestore( + self, + ) -> typing.Callable[ + [featurestore_service.UpdateFeaturestoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def delete_featurestore(self) -> typing.Callable[ - [featurestore_service.DeleteFeaturestoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_featurestore( + self, + ) -> typing.Callable[ + [featurestore_service.DeleteFeaturestoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def create_entity_type(self) -> typing.Callable[ - [featurestore_service.CreateEntityTypeRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_entity_type( + self, + ) -> typing.Callable[ + [featurestore_service.CreateEntityTypeRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_entity_type(self) -> typing.Callable[ - [featurestore_service.GetEntityTypeRequest], - typing.Union[ - entity_type.EntityType, - typing.Awaitable[entity_type.EntityType] - ]]: + def get_entity_type( + self, + ) -> typing.Callable[ + [featurestore_service.GetEntityTypeRequest], + typing.Union[entity_type.EntityType, typing.Awaitable[entity_type.EntityType]], + ]: raise NotImplementedError() @property - def list_entity_types(self) -> typing.Callable[ - [featurestore_service.ListEntityTypesRequest], - typing.Union[ - featurestore_service.ListEntityTypesResponse, - typing.Awaitable[featurestore_service.ListEntityTypesResponse] - ]]: + def list_entity_types( + self, + ) -> typing.Callable[ + [featurestore_service.ListEntityTypesRequest], + typing.Union[ + featurestore_service.ListEntityTypesResponse, + typing.Awaitable[featurestore_service.ListEntityTypesResponse], + ], + ]: raise NotImplementedError() @property - def update_entity_type(self) -> typing.Callable[ - [featurestore_service.UpdateEntityTypeRequest], - typing.Union[ - gca_entity_type.EntityType, - typing.Awaitable[gca_entity_type.EntityType] - ]]: + def update_entity_type( + self, + ) -> typing.Callable[ + [featurestore_service.UpdateEntityTypeRequest], + typing.Union[ + gca_entity_type.EntityType, typing.Awaitable[gca_entity_type.EntityType] + ], + ]: raise NotImplementedError() @property - def delete_entity_type(self) -> typing.Callable[ - [featurestore_service.DeleteEntityTypeRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_entity_type( + self, + ) -> typing.Callable[ + [featurestore_service.DeleteEntityTypeRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def create_feature(self) -> typing.Callable[ - [featurestore_service.CreateFeatureRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_feature( + self, + ) -> typing.Callable[ + [featurestore_service.CreateFeatureRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def batch_create_features(self) -> typing.Callable[ - [featurestore_service.BatchCreateFeaturesRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def batch_create_features( + self, + ) -> typing.Callable[ + [featurestore_service.BatchCreateFeaturesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_feature(self) -> typing.Callable[ - [featurestore_service.GetFeatureRequest], - typing.Union[ - feature.Feature, - typing.Awaitable[feature.Feature] - ]]: + def get_feature( + self, + ) -> typing.Callable[ + [featurestore_service.GetFeatureRequest], + typing.Union[feature.Feature, typing.Awaitable[feature.Feature]], + ]: raise NotImplementedError() @property - def list_features(self) -> typing.Callable[ - [featurestore_service.ListFeaturesRequest], - typing.Union[ - featurestore_service.ListFeaturesResponse, - typing.Awaitable[featurestore_service.ListFeaturesResponse] - ]]: + def list_features( + self, + ) -> typing.Callable[ + [featurestore_service.ListFeaturesRequest], + typing.Union[ + featurestore_service.ListFeaturesResponse, + typing.Awaitable[featurestore_service.ListFeaturesResponse], + ], + ]: raise NotImplementedError() @property - def update_feature(self) -> typing.Callable[ - [featurestore_service.UpdateFeatureRequest], - typing.Union[ - gca_feature.Feature, - typing.Awaitable[gca_feature.Feature] - ]]: + def update_feature( + self, + ) -> typing.Callable[ + [featurestore_service.UpdateFeatureRequest], + typing.Union[gca_feature.Feature, typing.Awaitable[gca_feature.Feature]], + ]: raise NotImplementedError() @property - def delete_feature(self) -> typing.Callable[ - [featurestore_service.DeleteFeatureRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_feature( + self, + ) -> typing.Callable[ + [featurestore_service.DeleteFeatureRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def import_feature_values(self) -> typing.Callable[ - [featurestore_service.ImportFeatureValuesRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def import_feature_values( + self, + ) -> typing.Callable[ + [featurestore_service.ImportFeatureValuesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def batch_read_feature_values(self) -> typing.Callable[ - [featurestore_service.BatchReadFeatureValuesRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def batch_read_feature_values( + self, + ) -> typing.Callable[ + [featurestore_service.BatchReadFeatureValuesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def search_features(self) -> typing.Callable[ - [featurestore_service.SearchFeaturesRequest], - typing.Union[ - featurestore_service.SearchFeaturesResponse, - typing.Awaitable[featurestore_service.SearchFeaturesResponse] - ]]: + def search_features( + self, + ) -> typing.Callable[ + [featurestore_service.SearchFeaturesRequest], + typing.Union[ + featurestore_service.SearchFeaturesResponse, + typing.Awaitable[featurestore_service.SearchFeaturesResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'FeaturestoreServiceTransport', -) +__all__ = ("FeaturestoreServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py index 48fb007e78..ab15959efd 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,21 +51,24 @@ class FeaturestoreServiceGrpcTransport(FeaturestoreServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -177,13 +180,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -216,7 +221,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -234,17 +239,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_featurestore(self) -> Callable[ - [featurestore_service.CreateFeaturestoreRequest], - operations.Operation]: + def create_featurestore( + self, + ) -> Callable[ + [featurestore_service.CreateFeaturestoreRequest], operations.Operation + ]: r"""Return a callable for the create featurestore method over gRPC. Creates a new Featurestore in a given project and @@ -260,18 +265,20 @@ def create_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_featurestore' not in self._stubs: - self._stubs['create_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeaturestore', + if "create_featurestore" not in self._stubs: + self._stubs["create_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeaturestore", request_serializer=featurestore_service.CreateFeaturestoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_featurestore'] + return self._stubs["create_featurestore"] @property - def get_featurestore(self) -> Callable[ - [featurestore_service.GetFeaturestoreRequest], - featurestore.Featurestore]: + def get_featurestore( + self, + ) -> Callable[ + [featurestore_service.GetFeaturestoreRequest], featurestore.Featurestore + ]: r"""Return a callable for the get featurestore method over gRPC. Gets details of a single Featurestore. @@ -286,18 +293,21 @@ def get_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_featurestore' not in self._stubs: - self._stubs['get_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeaturestore', + if "get_featurestore" not in self._stubs: + self._stubs["get_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeaturestore", request_serializer=featurestore_service.GetFeaturestoreRequest.serialize, response_deserializer=featurestore.Featurestore.deserialize, ) - return self._stubs['get_featurestore'] + return self._stubs["get_featurestore"] @property - def list_featurestores(self) -> Callable[ - [featurestore_service.ListFeaturestoresRequest], - featurestore_service.ListFeaturestoresResponse]: + def list_featurestores( + self, + ) -> Callable[ + [featurestore_service.ListFeaturestoresRequest], + featurestore_service.ListFeaturestoresResponse, + ]: r"""Return a callable for the list featurestores method over gRPC. Lists Featurestores in a given project and location. @@ -312,18 +322,20 @@ def list_featurestores(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_featurestores' not in self._stubs: - self._stubs['list_featurestores'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeaturestores', + if "list_featurestores" not in self._stubs: + self._stubs["list_featurestores"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeaturestores", request_serializer=featurestore_service.ListFeaturestoresRequest.serialize, response_deserializer=featurestore_service.ListFeaturestoresResponse.deserialize, ) - return self._stubs['list_featurestores'] + return self._stubs["list_featurestores"] @property - def update_featurestore(self) -> Callable[ - [featurestore_service.UpdateFeaturestoreRequest], - operations.Operation]: + def update_featurestore( + self, + ) -> Callable[ + [featurestore_service.UpdateFeaturestoreRequest], operations.Operation + ]: r"""Return a callable for the update featurestore method over gRPC. Updates the parameters of a single Featurestore. @@ -338,18 +350,20 @@ def update_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_featurestore' not in self._stubs: - self._stubs['update_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeaturestore', + if "update_featurestore" not in self._stubs: + self._stubs["update_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeaturestore", request_serializer=featurestore_service.UpdateFeaturestoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_featurestore'] + return self._stubs["update_featurestore"] @property - def delete_featurestore(self) -> Callable[ - [featurestore_service.DeleteFeaturestoreRequest], - operations.Operation]: + def delete_featurestore( + self, + ) -> Callable[ + [featurestore_service.DeleteFeaturestoreRequest], operations.Operation + ]: r"""Return a callable for the delete featurestore method over gRPC. Deletes a single Featurestore. The Featurestore must not contain @@ -366,18 +380,18 @@ def delete_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_featurestore' not in self._stubs: - self._stubs['delete_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeaturestore', + if "delete_featurestore" not in self._stubs: + self._stubs["delete_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeaturestore", request_serializer=featurestore_service.DeleteFeaturestoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_featurestore'] + return self._stubs["delete_featurestore"] @property - def create_entity_type(self) -> Callable[ - [featurestore_service.CreateEntityTypeRequest], - operations.Operation]: + def create_entity_type( + self, + ) -> Callable[[featurestore_service.CreateEntityTypeRequest], operations.Operation]: r"""Return a callable for the create entity type method over gRPC. Creates a new EntityType in a given Featurestore. @@ -392,18 +406,18 @@ def create_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_entity_type' not in self._stubs: - self._stubs['create_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateEntityType', + if "create_entity_type" not in self._stubs: + self._stubs["create_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateEntityType", request_serializer=featurestore_service.CreateEntityTypeRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_entity_type'] + return self._stubs["create_entity_type"] @property - def get_entity_type(self) -> Callable[ - [featurestore_service.GetEntityTypeRequest], - entity_type.EntityType]: + def get_entity_type( + self, + ) -> Callable[[featurestore_service.GetEntityTypeRequest], entity_type.EntityType]: r"""Return a callable for the get entity type method over gRPC. Gets details of a single EntityType. @@ -418,18 +432,21 @@ def get_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_entity_type' not in self._stubs: - self._stubs['get_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetEntityType', + if "get_entity_type" not in self._stubs: + self._stubs["get_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetEntityType", request_serializer=featurestore_service.GetEntityTypeRequest.serialize, response_deserializer=entity_type.EntityType.deserialize, ) - return self._stubs['get_entity_type'] + return self._stubs["get_entity_type"] @property - def list_entity_types(self) -> Callable[ - [featurestore_service.ListEntityTypesRequest], - featurestore_service.ListEntityTypesResponse]: + def list_entity_types( + self, + ) -> Callable[ + [featurestore_service.ListEntityTypesRequest], + featurestore_service.ListEntityTypesResponse, + ]: r"""Return a callable for the list entity types method over gRPC. Lists EntityTypes in a given Featurestore. @@ -444,18 +461,20 @@ def list_entity_types(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_entity_types' not in self._stubs: - self._stubs['list_entity_types'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListEntityTypes', + if "list_entity_types" not in self._stubs: + self._stubs["list_entity_types"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListEntityTypes", request_serializer=featurestore_service.ListEntityTypesRequest.serialize, response_deserializer=featurestore_service.ListEntityTypesResponse.deserialize, ) - return self._stubs['list_entity_types'] + return self._stubs["list_entity_types"] @property - def update_entity_type(self) -> Callable[ - [featurestore_service.UpdateEntityTypeRequest], - gca_entity_type.EntityType]: + def update_entity_type( + self, + ) -> Callable[ + [featurestore_service.UpdateEntityTypeRequest], gca_entity_type.EntityType + ]: r"""Return a callable for the update entity type method over gRPC. Updates the parameters of a single EntityType. @@ -470,18 +489,18 @@ def update_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_entity_type' not in self._stubs: - self._stubs['update_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateEntityType', + if "update_entity_type" not in self._stubs: + self._stubs["update_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateEntityType", request_serializer=featurestore_service.UpdateEntityTypeRequest.serialize, response_deserializer=gca_entity_type.EntityType.deserialize, ) - return self._stubs['update_entity_type'] + return self._stubs["update_entity_type"] @property - def delete_entity_type(self) -> Callable[ - [featurestore_service.DeleteEntityTypeRequest], - operations.Operation]: + def delete_entity_type( + self, + ) -> Callable[[featurestore_service.DeleteEntityTypeRequest], operations.Operation]: r"""Return a callable for the delete entity type method over gRPC. Deletes a single EntityType. The EntityType must not have any @@ -498,18 +517,18 @@ def delete_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_entity_type' not in self._stubs: - self._stubs['delete_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteEntityType', + if "delete_entity_type" not in self._stubs: + self._stubs["delete_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteEntityType", request_serializer=featurestore_service.DeleteEntityTypeRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_entity_type'] + return self._stubs["delete_entity_type"] @property - def create_feature(self) -> Callable[ - [featurestore_service.CreateFeatureRequest], - operations.Operation]: + def create_feature( + self, + ) -> Callable[[featurestore_service.CreateFeatureRequest], operations.Operation]: r"""Return a callable for the create feature method over gRPC. Creates a new Feature in a given EntityType. @@ -524,18 +543,20 @@ def create_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_feature' not in self._stubs: - self._stubs['create_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeature', + if "create_feature" not in self._stubs: + self._stubs["create_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeature", request_serializer=featurestore_service.CreateFeatureRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_feature'] + return self._stubs["create_feature"] @property - def batch_create_features(self) -> Callable[ - [featurestore_service.BatchCreateFeaturesRequest], - operations.Operation]: + def batch_create_features( + self, + ) -> Callable[ + [featurestore_service.BatchCreateFeaturesRequest], operations.Operation + ]: r"""Return a callable for the batch create features method over gRPC. Creates a batch of Features in a given EntityType. @@ -550,18 +571,18 @@ def batch_create_features(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_create_features' not in self._stubs: - self._stubs['batch_create_features'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchCreateFeatures', + if "batch_create_features" not in self._stubs: + self._stubs["batch_create_features"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchCreateFeatures", request_serializer=featurestore_service.BatchCreateFeaturesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_create_features'] + return self._stubs["batch_create_features"] @property - def get_feature(self) -> Callable[ - [featurestore_service.GetFeatureRequest], - feature.Feature]: + def get_feature( + self, + ) -> Callable[[featurestore_service.GetFeatureRequest], feature.Feature]: r"""Return a callable for the get feature method over gRPC. Gets details of a single Feature. @@ -576,18 +597,21 @@ def get_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_feature' not in self._stubs: - self._stubs['get_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeature', + if "get_feature" not in self._stubs: + self._stubs["get_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeature", request_serializer=featurestore_service.GetFeatureRequest.serialize, response_deserializer=feature.Feature.deserialize, ) - return self._stubs['get_feature'] + return self._stubs["get_feature"] @property - def list_features(self) -> Callable[ - [featurestore_service.ListFeaturesRequest], - featurestore_service.ListFeaturesResponse]: + def list_features( + self, + ) -> Callable[ + [featurestore_service.ListFeaturesRequest], + featurestore_service.ListFeaturesResponse, + ]: r"""Return a callable for the list features method over gRPC. Lists Features in a given EntityType. @@ -602,18 +626,18 @@ def list_features(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_features' not in self._stubs: - self._stubs['list_features'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeatures', + if "list_features" not in self._stubs: + self._stubs["list_features"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeatures", request_serializer=featurestore_service.ListFeaturesRequest.serialize, response_deserializer=featurestore_service.ListFeaturesResponse.deserialize, ) - return self._stubs['list_features'] + return self._stubs["list_features"] @property - def update_feature(self) -> Callable[ - [featurestore_service.UpdateFeatureRequest], - gca_feature.Feature]: + def update_feature( + self, + ) -> Callable[[featurestore_service.UpdateFeatureRequest], gca_feature.Feature]: r"""Return a callable for the update feature method over gRPC. Updates the parameters of a single Feature. @@ -628,18 +652,18 @@ def update_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_feature' not in self._stubs: - self._stubs['update_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeature', + if "update_feature" not in self._stubs: + self._stubs["update_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeature", request_serializer=featurestore_service.UpdateFeatureRequest.serialize, response_deserializer=gca_feature.Feature.deserialize, ) - return self._stubs['update_feature'] + return self._stubs["update_feature"] @property - def delete_feature(self) -> Callable[ - [featurestore_service.DeleteFeatureRequest], - operations.Operation]: + def delete_feature( + self, + ) -> Callable[[featurestore_service.DeleteFeatureRequest], operations.Operation]: r"""Return a callable for the delete feature method over gRPC. Deletes a single Feature. @@ -654,18 +678,20 @@ def delete_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_feature' not in self._stubs: - self._stubs['delete_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeature', + if "delete_feature" not in self._stubs: + self._stubs["delete_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeature", request_serializer=featurestore_service.DeleteFeatureRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_feature'] + return self._stubs["delete_feature"] @property - def import_feature_values(self) -> Callable[ - [featurestore_service.ImportFeatureValuesRequest], - operations.Operation]: + def import_feature_values( + self, + ) -> Callable[ + [featurestore_service.ImportFeatureValuesRequest], operations.Operation + ]: r"""Return a callable for the import feature values method over gRPC. Imports Feature values into the Featurestore from a @@ -700,18 +726,20 @@ def import_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_feature_values' not in self._stubs: - self._stubs['import_feature_values'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ImportFeatureValues', + if "import_feature_values" not in self._stubs: + self._stubs["import_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ImportFeatureValues", request_serializer=featurestore_service.ImportFeatureValuesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_feature_values'] + return self._stubs["import_feature_values"] @property - def batch_read_feature_values(self) -> Callable[ - [featurestore_service.BatchReadFeatureValuesRequest], - operations.Operation]: + def batch_read_feature_values( + self, + ) -> Callable[ + [featurestore_service.BatchReadFeatureValuesRequest], operations.Operation + ]: r"""Return a callable for the batch read feature values method over gRPC. Batch reads Feature values from a Featurestore. @@ -731,18 +759,21 @@ def batch_read_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_read_feature_values' not in self._stubs: - self._stubs['batch_read_feature_values'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchReadFeatureValues', + if "batch_read_feature_values" not in self._stubs: + self._stubs["batch_read_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchReadFeatureValues", request_serializer=featurestore_service.BatchReadFeatureValuesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_read_feature_values'] + return self._stubs["batch_read_feature_values"] @property - def search_features(self) -> Callable[ - [featurestore_service.SearchFeaturesRequest], - featurestore_service.SearchFeaturesResponse]: + def search_features( + self, + ) -> Callable[ + [featurestore_service.SearchFeaturesRequest], + featurestore_service.SearchFeaturesResponse, + ]: r"""Return a callable for the search features method over gRPC. Searches Features matching a query in a given @@ -758,15 +789,13 @@ def search_features(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_features' not in self._stubs: - self._stubs['search_features'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/SearchFeatures', + if "search_features" not in self._stubs: + self._stubs["search_features"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/SearchFeatures", request_serializer=featurestore_service.SearchFeaturesRequest.serialize, response_deserializer=featurestore_service.SearchFeaturesResponse.deserialize, ) - return self._stubs['search_features'] + return self._stubs["search_features"] -__all__ = ( - 'FeaturestoreServiceGrpcTransport', -) +__all__ = ("FeaturestoreServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py index 97114a68be..e0a4e35394 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import entity_type @@ -58,13 +58,15 @@ class FeaturestoreServiceGrpcAsyncIOTransport(FeaturestoreServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -93,22 +95,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -247,9 +251,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_featurestore(self) -> Callable[ - [featurestore_service.CreateFeaturestoreRequest], - Awaitable[operations.Operation]]: + def create_featurestore( + self, + ) -> Callable[ + [featurestore_service.CreateFeaturestoreRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the create featurestore method over gRPC. Creates a new Featurestore in a given project and @@ -265,18 +272,21 @@ def create_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_featurestore' not in self._stubs: - self._stubs['create_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeaturestore', + if "create_featurestore" not in self._stubs: + self._stubs["create_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeaturestore", request_serializer=featurestore_service.CreateFeaturestoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_featurestore'] + return self._stubs["create_featurestore"] @property - def get_featurestore(self) -> Callable[ - [featurestore_service.GetFeaturestoreRequest], - Awaitable[featurestore.Featurestore]]: + def get_featurestore( + self, + ) -> Callable[ + [featurestore_service.GetFeaturestoreRequest], + Awaitable[featurestore.Featurestore], + ]: r"""Return a callable for the get featurestore method over gRPC. Gets details of a single Featurestore. @@ -291,18 +301,21 @@ def get_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_featurestore' not in self._stubs: - self._stubs['get_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeaturestore', + if "get_featurestore" not in self._stubs: + self._stubs["get_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeaturestore", request_serializer=featurestore_service.GetFeaturestoreRequest.serialize, response_deserializer=featurestore.Featurestore.deserialize, ) - return self._stubs['get_featurestore'] + return self._stubs["get_featurestore"] @property - def list_featurestores(self) -> Callable[ - [featurestore_service.ListFeaturestoresRequest], - Awaitable[featurestore_service.ListFeaturestoresResponse]]: + def list_featurestores( + self, + ) -> Callable[ + [featurestore_service.ListFeaturestoresRequest], + Awaitable[featurestore_service.ListFeaturestoresResponse], + ]: r"""Return a callable for the list featurestores method over gRPC. Lists Featurestores in a given project and location. @@ -317,18 +330,21 @@ def list_featurestores(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_featurestores' not in self._stubs: - self._stubs['list_featurestores'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeaturestores', + if "list_featurestores" not in self._stubs: + self._stubs["list_featurestores"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeaturestores", request_serializer=featurestore_service.ListFeaturestoresRequest.serialize, response_deserializer=featurestore_service.ListFeaturestoresResponse.deserialize, ) - return self._stubs['list_featurestores'] + return self._stubs["list_featurestores"] @property - def update_featurestore(self) -> Callable[ - [featurestore_service.UpdateFeaturestoreRequest], - Awaitable[operations.Operation]]: + def update_featurestore( + self, + ) -> Callable[ + [featurestore_service.UpdateFeaturestoreRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the update featurestore method over gRPC. Updates the parameters of a single Featurestore. @@ -343,18 +359,21 @@ def update_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_featurestore' not in self._stubs: - self._stubs['update_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeaturestore', + if "update_featurestore" not in self._stubs: + self._stubs["update_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeaturestore", request_serializer=featurestore_service.UpdateFeaturestoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_featurestore'] + return self._stubs["update_featurestore"] @property - def delete_featurestore(self) -> Callable[ - [featurestore_service.DeleteFeaturestoreRequest], - Awaitable[operations.Operation]]: + def delete_featurestore( + self, + ) -> Callable[ + [featurestore_service.DeleteFeaturestoreRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete featurestore method over gRPC. Deletes a single Featurestore. The Featurestore must not contain @@ -371,18 +390,20 @@ def delete_featurestore(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_featurestore' not in self._stubs: - self._stubs['delete_featurestore'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeaturestore', + if "delete_featurestore" not in self._stubs: + self._stubs["delete_featurestore"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeaturestore", request_serializer=featurestore_service.DeleteFeaturestoreRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_featurestore'] + return self._stubs["delete_featurestore"] @property - def create_entity_type(self) -> Callable[ - [featurestore_service.CreateEntityTypeRequest], - Awaitable[operations.Operation]]: + def create_entity_type( + self, + ) -> Callable[ + [featurestore_service.CreateEntityTypeRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create entity type method over gRPC. Creates a new EntityType in a given Featurestore. @@ -397,18 +418,20 @@ def create_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_entity_type' not in self._stubs: - self._stubs['create_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateEntityType', + if "create_entity_type" not in self._stubs: + self._stubs["create_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateEntityType", request_serializer=featurestore_service.CreateEntityTypeRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_entity_type'] + return self._stubs["create_entity_type"] @property - def get_entity_type(self) -> Callable[ - [featurestore_service.GetEntityTypeRequest], - Awaitable[entity_type.EntityType]]: + def get_entity_type( + self, + ) -> Callable[ + [featurestore_service.GetEntityTypeRequest], Awaitable[entity_type.EntityType] + ]: r"""Return a callable for the get entity type method over gRPC. Gets details of a single EntityType. @@ -423,18 +446,21 @@ def get_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_entity_type' not in self._stubs: - self._stubs['get_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetEntityType', + if "get_entity_type" not in self._stubs: + self._stubs["get_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetEntityType", request_serializer=featurestore_service.GetEntityTypeRequest.serialize, response_deserializer=entity_type.EntityType.deserialize, ) - return self._stubs['get_entity_type'] + return self._stubs["get_entity_type"] @property - def list_entity_types(self) -> Callable[ - [featurestore_service.ListEntityTypesRequest], - Awaitable[featurestore_service.ListEntityTypesResponse]]: + def list_entity_types( + self, + ) -> Callable[ + [featurestore_service.ListEntityTypesRequest], + Awaitable[featurestore_service.ListEntityTypesResponse], + ]: r"""Return a callable for the list entity types method over gRPC. Lists EntityTypes in a given Featurestore. @@ -449,18 +475,21 @@ def list_entity_types(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_entity_types' not in self._stubs: - self._stubs['list_entity_types'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListEntityTypes', + if "list_entity_types" not in self._stubs: + self._stubs["list_entity_types"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListEntityTypes", request_serializer=featurestore_service.ListEntityTypesRequest.serialize, response_deserializer=featurestore_service.ListEntityTypesResponse.deserialize, ) - return self._stubs['list_entity_types'] + return self._stubs["list_entity_types"] @property - def update_entity_type(self) -> Callable[ - [featurestore_service.UpdateEntityTypeRequest], - Awaitable[gca_entity_type.EntityType]]: + def update_entity_type( + self, + ) -> Callable[ + [featurestore_service.UpdateEntityTypeRequest], + Awaitable[gca_entity_type.EntityType], + ]: r"""Return a callable for the update entity type method over gRPC. Updates the parameters of a single EntityType. @@ -475,18 +504,20 @@ def update_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_entity_type' not in self._stubs: - self._stubs['update_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateEntityType', + if "update_entity_type" not in self._stubs: + self._stubs["update_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateEntityType", request_serializer=featurestore_service.UpdateEntityTypeRequest.serialize, response_deserializer=gca_entity_type.EntityType.deserialize, ) - return self._stubs['update_entity_type'] + return self._stubs["update_entity_type"] @property - def delete_entity_type(self) -> Callable[ - [featurestore_service.DeleteEntityTypeRequest], - Awaitable[operations.Operation]]: + def delete_entity_type( + self, + ) -> Callable[ + [featurestore_service.DeleteEntityTypeRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete entity type method over gRPC. Deletes a single EntityType. The EntityType must not have any @@ -503,18 +534,20 @@ def delete_entity_type(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_entity_type' not in self._stubs: - self._stubs['delete_entity_type'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteEntityType', + if "delete_entity_type" not in self._stubs: + self._stubs["delete_entity_type"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteEntityType", request_serializer=featurestore_service.DeleteEntityTypeRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_entity_type'] + return self._stubs["delete_entity_type"] @property - def create_feature(self) -> Callable[ - [featurestore_service.CreateFeatureRequest], - Awaitable[operations.Operation]]: + def create_feature( + self, + ) -> Callable[ + [featurestore_service.CreateFeatureRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create feature method over gRPC. Creates a new Feature in a given EntityType. @@ -529,18 +562,21 @@ def create_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_feature' not in self._stubs: - self._stubs['create_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeature', + if "create_feature" not in self._stubs: + self._stubs["create_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/CreateFeature", request_serializer=featurestore_service.CreateFeatureRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_feature'] + return self._stubs["create_feature"] @property - def batch_create_features(self) -> Callable[ - [featurestore_service.BatchCreateFeaturesRequest], - Awaitable[operations.Operation]]: + def batch_create_features( + self, + ) -> Callable[ + [featurestore_service.BatchCreateFeaturesRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the batch create features method over gRPC. Creates a batch of Features in a given EntityType. @@ -555,18 +591,18 @@ def batch_create_features(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_create_features' not in self._stubs: - self._stubs['batch_create_features'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchCreateFeatures', + if "batch_create_features" not in self._stubs: + self._stubs["batch_create_features"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchCreateFeatures", request_serializer=featurestore_service.BatchCreateFeaturesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_create_features'] + return self._stubs["batch_create_features"] @property - def get_feature(self) -> Callable[ - [featurestore_service.GetFeatureRequest], - Awaitable[feature.Feature]]: + def get_feature( + self, + ) -> Callable[[featurestore_service.GetFeatureRequest], Awaitable[feature.Feature]]: r"""Return a callable for the get feature method over gRPC. Gets details of a single Feature. @@ -581,18 +617,21 @@ def get_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_feature' not in self._stubs: - self._stubs['get_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeature', + if "get_feature" not in self._stubs: + self._stubs["get_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/GetFeature", request_serializer=featurestore_service.GetFeatureRequest.serialize, response_deserializer=feature.Feature.deserialize, ) - return self._stubs['get_feature'] + return self._stubs["get_feature"] @property - def list_features(self) -> Callable[ - [featurestore_service.ListFeaturesRequest], - Awaitable[featurestore_service.ListFeaturesResponse]]: + def list_features( + self, + ) -> Callable[ + [featurestore_service.ListFeaturesRequest], + Awaitable[featurestore_service.ListFeaturesResponse], + ]: r"""Return a callable for the list features method over gRPC. Lists Features in a given EntityType. @@ -607,18 +646,20 @@ def list_features(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_features' not in self._stubs: - self._stubs['list_features'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeatures', + if "list_features" not in self._stubs: + self._stubs["list_features"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ListFeatures", request_serializer=featurestore_service.ListFeaturesRequest.serialize, response_deserializer=featurestore_service.ListFeaturesResponse.deserialize, ) - return self._stubs['list_features'] + return self._stubs["list_features"] @property - def update_feature(self) -> Callable[ - [featurestore_service.UpdateFeatureRequest], - Awaitable[gca_feature.Feature]]: + def update_feature( + self, + ) -> Callable[ + [featurestore_service.UpdateFeatureRequest], Awaitable[gca_feature.Feature] + ]: r"""Return a callable for the update feature method over gRPC. Updates the parameters of a single Feature. @@ -633,18 +674,20 @@ def update_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_feature' not in self._stubs: - self._stubs['update_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeature', + if "update_feature" not in self._stubs: + self._stubs["update_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/UpdateFeature", request_serializer=featurestore_service.UpdateFeatureRequest.serialize, response_deserializer=gca_feature.Feature.deserialize, ) - return self._stubs['update_feature'] + return self._stubs["update_feature"] @property - def delete_feature(self) -> Callable[ - [featurestore_service.DeleteFeatureRequest], - Awaitable[operations.Operation]]: + def delete_feature( + self, + ) -> Callable[ + [featurestore_service.DeleteFeatureRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete feature method over gRPC. Deletes a single Feature. @@ -659,18 +702,21 @@ def delete_feature(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_feature' not in self._stubs: - self._stubs['delete_feature'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeature', + if "delete_feature" not in self._stubs: + self._stubs["delete_feature"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/DeleteFeature", request_serializer=featurestore_service.DeleteFeatureRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_feature'] + return self._stubs["delete_feature"] @property - def import_feature_values(self) -> Callable[ - [featurestore_service.ImportFeatureValuesRequest], - Awaitable[operations.Operation]]: + def import_feature_values( + self, + ) -> Callable[ + [featurestore_service.ImportFeatureValuesRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the import feature values method over gRPC. Imports Feature values into the Featurestore from a @@ -705,18 +751,21 @@ def import_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_feature_values' not in self._stubs: - self._stubs['import_feature_values'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/ImportFeatureValues', + if "import_feature_values" not in self._stubs: + self._stubs["import_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ImportFeatureValues", request_serializer=featurestore_service.ImportFeatureValuesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_feature_values'] + return self._stubs["import_feature_values"] @property - def batch_read_feature_values(self) -> Callable[ - [featurestore_service.BatchReadFeatureValuesRequest], - Awaitable[operations.Operation]]: + def batch_read_feature_values( + self, + ) -> Callable[ + [featurestore_service.BatchReadFeatureValuesRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the batch read feature values method over gRPC. Batch reads Feature values from a Featurestore. @@ -736,18 +785,21 @@ def batch_read_feature_values(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_read_feature_values' not in self._stubs: - self._stubs['batch_read_feature_values'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchReadFeatureValues', + if "batch_read_feature_values" not in self._stubs: + self._stubs["batch_read_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/BatchReadFeatureValues", request_serializer=featurestore_service.BatchReadFeatureValuesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_read_feature_values'] + return self._stubs["batch_read_feature_values"] @property - def search_features(self) -> Callable[ - [featurestore_service.SearchFeaturesRequest], - Awaitable[featurestore_service.SearchFeaturesResponse]]: + def search_features( + self, + ) -> Callable[ + [featurestore_service.SearchFeaturesRequest], + Awaitable[featurestore_service.SearchFeaturesResponse], + ]: r"""Return a callable for the search features method over gRPC. Searches Features matching a query in a given @@ -763,15 +815,13 @@ def search_features(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_features' not in self._stubs: - self._stubs['search_features'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.FeaturestoreService/SearchFeatures', + if "search_features" not in self._stubs: + self._stubs["search_features"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/SearchFeatures", request_serializer=featurestore_service.SearchFeaturesRequest.serialize, response_deserializer=featurestore_service.SearchFeaturesResponse.deserialize, ) - return self._stubs['search_features'] + return self._stubs["search_features"] -__all__ = ( - 'FeaturestoreServiceGrpcAsyncIOTransport', -) +__all__ = ("FeaturestoreServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py index 853d7b928c..1eeda9dcdd 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import IndexEndpointServiceAsyncClient __all__ = ( - 'IndexEndpointServiceClient', - 'IndexEndpointServiceAsyncClient', + "IndexEndpointServiceClient", + "IndexEndpointServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py index 704dd1fda4..06dd2a9d72 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -55,24 +55,42 @@ class IndexEndpointServiceAsyncClient: index_path = staticmethod(IndexEndpointServiceClient.index_path) parse_index_path = staticmethod(IndexEndpointServiceClient.parse_index_path) index_endpoint_path = staticmethod(IndexEndpointServiceClient.index_endpoint_path) - parse_index_endpoint_path = staticmethod(IndexEndpointServiceClient.parse_index_endpoint_path) + parse_index_endpoint_path = staticmethod( + IndexEndpointServiceClient.parse_index_endpoint_path + ) index_endpoint_path = staticmethod(IndexEndpointServiceClient.index_endpoint_path) - parse_index_endpoint_path = staticmethod(IndexEndpointServiceClient.parse_index_endpoint_path) + parse_index_endpoint_path = staticmethod( + IndexEndpointServiceClient.parse_index_endpoint_path + ) - common_billing_account_path = staticmethod(IndexEndpointServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(IndexEndpointServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + IndexEndpointServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + IndexEndpointServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(IndexEndpointServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(IndexEndpointServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + IndexEndpointServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(IndexEndpointServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(IndexEndpointServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + IndexEndpointServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + IndexEndpointServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(IndexEndpointServiceClient.common_project_path) - parse_common_project_path = staticmethod(IndexEndpointServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + IndexEndpointServiceClient.parse_common_project_path + ) common_location_path = staticmethod(IndexEndpointServiceClient.common_location_path) - parse_common_location_path = staticmethod(IndexEndpointServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + IndexEndpointServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -115,14 +133,19 @@ def transport(self) -> IndexEndpointServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(IndexEndpointServiceClient).get_transport_class, type(IndexEndpointServiceClient)) + get_transport_class = functools.partial( + type(IndexEndpointServiceClient).get_transport_class, + type(IndexEndpointServiceClient), + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, IndexEndpointServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, IndexEndpointServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the index endpoint service client. Args: @@ -161,18 +184,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_index_endpoint(self, - request: index_endpoint_service.CreateIndexEndpointRequest = None, - *, - parent: str = None, - index_endpoint: gca_index_endpoint.IndexEndpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_index_endpoint( + self, + request: index_endpoint_service.CreateIndexEndpointRequest = None, + *, + parent: str = None, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an IndexEndpoint. Args: @@ -214,8 +237,10 @@ async def create_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.CreateIndexEndpointRequest(request) @@ -238,18 +263,11 @@ async def create_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -262,14 +280,15 @@ async def create_index_endpoint(self, # Done; return the response. return response - async def get_index_endpoint(self, - request: index_endpoint_service.GetIndexEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> index_endpoint.IndexEndpoint: + async def get_index_endpoint( + self, + request: index_endpoint_service.GetIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index_endpoint.IndexEndpoint: r"""Gets an IndexEndpoint. Args: @@ -303,8 +322,10 @@ async def get_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.GetIndexEndpointRequest(request) @@ -325,30 +346,24 @@ async def get_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_index_endpoints(self, - request: index_endpoint_service.ListIndexEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListIndexEndpointsAsyncPager: + async def list_index_endpoints( + self, + request: index_endpoint_service.ListIndexEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexEndpointsAsyncPager: r"""Lists IndexEndpoints in a Location. Args: @@ -384,8 +399,10 @@ async def list_index_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.ListIndexEndpointsRequest(request) @@ -406,40 +423,31 @@ async def list_index_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListIndexEndpointsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_index_endpoint(self, - request: index_endpoint_service.UpdateIndexEndpointRequest = None, - *, - index_endpoint: gca_index_endpoint.IndexEndpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_index_endpoint.IndexEndpoint: + async def update_index_endpoint( + self, + request: index_endpoint_service.UpdateIndexEndpointRequest = None, + *, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_index_endpoint.IndexEndpoint: r"""Updates an IndexEndpoint. Args: @@ -479,8 +487,10 @@ async def update_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.UpdateIndexEndpointRequest(request) @@ -503,30 +513,26 @@ async def update_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index_endpoint.name', request.index_endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index_endpoint.name", request.index_endpoint.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_index_endpoint(self, - request: index_endpoint_service.DeleteIndexEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_index_endpoint( + self, + request: index_endpoint_service.DeleteIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an IndexEndpoint. Args: @@ -572,8 +578,10 @@ async def delete_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.DeleteIndexEndpointRequest(request) @@ -594,18 +602,11 @@ async def delete_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -618,15 +619,16 @@ async def delete_index_endpoint(self, # Done; return the response. return response - async def deploy_index(self, - request: index_endpoint_service.DeployIndexRequest = None, - *, - index_endpoint: str = None, - deployed_index: gca_index_endpoint.DeployedIndex = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def deploy_index( + self, + request: index_endpoint_service.DeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index: gca_index_endpoint.DeployedIndex = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deploys an Index into this IndexEndpoint, creating a DeployedIndex within it. Only non-empty Indexes can be deployed. @@ -672,8 +674,10 @@ async def deploy_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.DeployIndexRequest(request) @@ -696,18 +700,13 @@ async def deploy_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index_endpoint', request.index_endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index_endpoint", request.index_endpoint),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -720,15 +719,16 @@ async def deploy_index(self, # Done; return the response. return response - async def undeploy_index(self, - request: index_endpoint_service.UndeployIndexRequest = None, - *, - index_endpoint: str = None, - deployed_index_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def undeploy_index( + self, + request: index_endpoint_service.UndeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Undeploys an Index from an IndexEndpoint, removing a DeployedIndex from it, and freeing all resources it's using. @@ -774,8 +774,10 @@ async def undeploy_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_endpoint_service.UndeployIndexRequest(request) @@ -798,18 +800,13 @@ async def undeploy_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index_endpoint', request.index_endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index_endpoint", request.index_endpoint),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -823,21 +820,14 @@ async def undeploy_index(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'IndexEndpointServiceAsyncClient', -) +__all__ = ("IndexEndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py index 9933c45371..373410e6e7 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -55,13 +55,16 @@ class IndexEndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[IndexEndpointServiceTransport]] - _transport_registry['grpc'] = IndexEndpointServiceGrpcTransport - _transport_registry['grpc_asyncio'] = IndexEndpointServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[IndexEndpointServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[IndexEndpointServiceTransport]] + _transport_registry["grpc"] = IndexEndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = IndexEndpointServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[IndexEndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -112,7 +115,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -147,9 +150,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: IndexEndpointServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -164,99 +166,120 @@ def transport(self) -> IndexEndpointServiceTransport: return self._transport @staticmethod - def index_path(project: str,location: str,index: str,) -> str: + def index_path(project: str, location: str, index: str,) -> str: """Return a fully-qualified index string.""" - return "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + return "projects/{project}/locations/{location}/indexes/{index}".format( + project=project, location=location, index=index, + ) @staticmethod - def parse_index_path(path: str) -> Dict[str,str]: + def parse_index_path(path: str) -> Dict[str, str]: """Parse a index path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexes/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/indexes/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def index_endpoint_path(project: str,location: str,index_endpoint: str,) -> str: + def index_endpoint_path(project: str, location: str, index_endpoint: str,) -> str: """Return a fully-qualified index_endpoint string.""" - return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( + project=project, location=location, index_endpoint=index_endpoint, + ) @staticmethod - def parse_index_endpoint_path(path: str) -> Dict[str,str]: + def parse_index_endpoint_path(path: str) -> Dict[str, str]: """Parse a index_endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def index_endpoint_path(project: str,location: str,index_endpoint: str,) -> str: + def index_endpoint_path(project: str, location: str, index_endpoint: str,) -> str: """Return a fully-qualified index_endpoint string.""" - return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( + project=project, location=location, index_endpoint=index_endpoint, + ) @staticmethod - def parse_index_endpoint_path(path: str) -> Dict[str,str]: + def parse_index_endpoint_path(path: str) -> Dict[str, str]: """Parse a index_endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, IndexEndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, IndexEndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the index endpoint service client. Args: @@ -300,7 +323,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -310,7 +335,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -322,7 +349,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -334,8 +363,10 @@ def __init__(self, *, if isinstance(transport, IndexEndpointServiceTransport): # transport is a IndexEndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -354,15 +385,16 @@ def __init__(self, *, client_info=client_info, ) - def create_index_endpoint(self, - request: index_endpoint_service.CreateIndexEndpointRequest = None, - *, - parent: str = None, - index_endpoint: gca_index_endpoint.IndexEndpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_index_endpoint( + self, + request: index_endpoint_service.CreateIndexEndpointRequest = None, + *, + parent: str = None, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates an IndexEndpoint. Args: @@ -404,8 +436,10 @@ def create_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.CreateIndexEndpointRequest. @@ -429,18 +463,11 @@ def create_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -453,14 +480,15 @@ def create_index_endpoint(self, # Done; return the response. return response - def get_index_endpoint(self, - request: index_endpoint_service.GetIndexEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> index_endpoint.IndexEndpoint: + def get_index_endpoint( + self, + request: index_endpoint_service.GetIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index_endpoint.IndexEndpoint: r"""Gets an IndexEndpoint. Args: @@ -494,8 +522,10 @@ def get_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.GetIndexEndpointRequest. @@ -517,30 +547,24 @@ def get_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_index_endpoints(self, - request: index_endpoint_service.ListIndexEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListIndexEndpointsPager: + def list_index_endpoints( + self, + request: index_endpoint_service.ListIndexEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexEndpointsPager: r"""Lists IndexEndpoints in a Location. Args: @@ -576,8 +600,10 @@ def list_index_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.ListIndexEndpointsRequest. @@ -599,40 +625,31 @@ def list_index_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListIndexEndpointsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_index_endpoint(self, - request: index_endpoint_service.UpdateIndexEndpointRequest = None, - *, - index_endpoint: gca_index_endpoint.IndexEndpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_index_endpoint.IndexEndpoint: + def update_index_endpoint( + self, + request: index_endpoint_service.UpdateIndexEndpointRequest = None, + *, + index_endpoint: gca_index_endpoint.IndexEndpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_index_endpoint.IndexEndpoint: r"""Updates an IndexEndpoint. Args: @@ -672,8 +689,10 @@ def update_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.UpdateIndexEndpointRequest. @@ -697,30 +716,26 @@ def update_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index_endpoint.name', request.index_endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index_endpoint.name", request.index_endpoint.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_index_endpoint(self, - request: index_endpoint_service.DeleteIndexEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_index_endpoint( + self, + request: index_endpoint_service.DeleteIndexEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes an IndexEndpoint. Args: @@ -766,8 +781,10 @@ def delete_index_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.DeleteIndexEndpointRequest. @@ -789,18 +806,11 @@ def delete_index_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -813,15 +823,16 @@ def delete_index_endpoint(self, # Done; return the response. return response - def deploy_index(self, - request: index_endpoint_service.DeployIndexRequest = None, - *, - index_endpoint: str = None, - deployed_index: gca_index_endpoint.DeployedIndex = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def deploy_index( + self, + request: index_endpoint_service.DeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index: gca_index_endpoint.DeployedIndex = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deploys an Index into this IndexEndpoint, creating a DeployedIndex within it. Only non-empty Indexes can be deployed. @@ -867,8 +878,10 @@ def deploy_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.DeployIndexRequest. @@ -892,18 +905,13 @@ def deploy_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index_endpoint', request.index_endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index_endpoint", request.index_endpoint),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -916,15 +924,16 @@ def deploy_index(self, # Done; return the response. return response - def undeploy_index(self, - request: index_endpoint_service.UndeployIndexRequest = None, - *, - index_endpoint: str = None, - deployed_index_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def undeploy_index( + self, + request: index_endpoint_service.UndeployIndexRequest = None, + *, + index_endpoint: str = None, + deployed_index_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Undeploys an Index from an IndexEndpoint, removing a DeployedIndex from it, and freeing all resources it's using. @@ -970,8 +979,10 @@ def undeploy_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_endpoint_service.UndeployIndexRequest. @@ -995,18 +1006,13 @@ def undeploy_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index_endpoint', request.index_endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index_endpoint", request.index_endpoint),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1020,21 +1026,14 @@ def undeploy_index(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'IndexEndpointServiceClient', -) +__all__ = ("IndexEndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py index 7c38beadfd..ae7b2cdbf9 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import index_endpoint from google.cloud.aiplatform_v1beta1.types import index_endpoint_service @@ -38,12 +47,15 @@ class ListIndexEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., index_endpoint_service.ListIndexEndpointsResponse], - request: index_endpoint_service.ListIndexEndpointsRequest, - response: index_endpoint_service.ListIndexEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., index_endpoint_service.ListIndexEndpointsResponse], + request: index_endpoint_service.ListIndexEndpointsRequest, + response: index_endpoint_service.ListIndexEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[index_endpoint.IndexEndpoint]: yield from page.index_endpoints def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListIndexEndpointsAsyncPager: @@ -97,12 +109,17 @@ class ListIndexEndpointsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[index_endpoint_service.ListIndexEndpointsResponse]], - request: index_endpoint_service.ListIndexEndpointsRequest, - response: index_endpoint_service.ListIndexEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[index_endpoint_service.ListIndexEndpointsResponse] + ], + request: index_endpoint_service.ListIndexEndpointsRequest, + response: index_endpoint_service.ListIndexEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -124,7 +141,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[index_endpoint_service.ListIndexEndpointsResponse]: + async def pages( + self, + ) -> AsyncIterable[index_endpoint_service.ListIndexEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +159,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py index dd025dddb8..9ce68726cf 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/__init__.py @@ -24,12 +24,14 @@ # Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[IndexEndpointServiceTransport]] -_transport_registry['grpc'] = IndexEndpointServiceGrpcTransport -_transport_registry['grpc_asyncio'] = IndexEndpointServiceGrpcAsyncIOTransport +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[IndexEndpointServiceTransport]] +_transport_registry["grpc"] = IndexEndpointServiceGrpcTransport +_transport_registry["grpc_asyncio"] = IndexEndpointServiceGrpcAsyncIOTransport __all__ = ( - 'IndexEndpointServiceTransport', - 'IndexEndpointServiceGrpcTransport', - 'IndexEndpointServiceGrpcAsyncIOTransport', + "IndexEndpointServiceTransport", + "IndexEndpointServiceGrpcTransport", + "IndexEndpointServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py index e16f56dd80..5ace621f9b 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class IndexEndpointServiceTransport(abc.ABC): """Abstract transport class for IndexEndpointService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -80,8 +80,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -90,17 +90,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -114,9 +116,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_index_endpoint: gapic_v1.method.wrap_method( - self.get_index_endpoint, - default_timeout=None, - client_info=client_info, + self.get_index_endpoint, default_timeout=None, client_info=client_info, ), self.list_index_endpoints: gapic_v1.method.wrap_method( self.list_index_endpoints, @@ -134,16 +134,11 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.deploy_index: gapic_v1.method.wrap_method( - self.deploy_index, - default_timeout=None, - client_info=client_info, + self.deploy_index, default_timeout=None, client_info=client_info, ), self.undeploy_index: gapic_v1.method.wrap_method( - self.undeploy_index, - default_timeout=None, - client_info=client_info, + self.undeploy_index, default_timeout=None, client_info=client_info, ), - } @property @@ -152,69 +147,75 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_index_endpoint(self) -> typing.Callable[ - [index_endpoint_service.CreateIndexEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_index_endpoint( + self, + ) -> typing.Callable[ + [index_endpoint_service.CreateIndexEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_index_endpoint(self) -> typing.Callable[ - [index_endpoint_service.GetIndexEndpointRequest], - typing.Union[ - index_endpoint.IndexEndpoint, - typing.Awaitable[index_endpoint.IndexEndpoint] - ]]: + def get_index_endpoint( + self, + ) -> typing.Callable[ + [index_endpoint_service.GetIndexEndpointRequest], + typing.Union[ + index_endpoint.IndexEndpoint, typing.Awaitable[index_endpoint.IndexEndpoint] + ], + ]: raise NotImplementedError() @property - def list_index_endpoints(self) -> typing.Callable[ - [index_endpoint_service.ListIndexEndpointsRequest], - typing.Union[ - index_endpoint_service.ListIndexEndpointsResponse, - typing.Awaitable[index_endpoint_service.ListIndexEndpointsResponse] - ]]: + def list_index_endpoints( + self, + ) -> typing.Callable[ + [index_endpoint_service.ListIndexEndpointsRequest], + typing.Union[ + index_endpoint_service.ListIndexEndpointsResponse, + typing.Awaitable[index_endpoint_service.ListIndexEndpointsResponse], + ], + ]: raise NotImplementedError() @property - def update_index_endpoint(self) -> typing.Callable[ - [index_endpoint_service.UpdateIndexEndpointRequest], - typing.Union[ - gca_index_endpoint.IndexEndpoint, - typing.Awaitable[gca_index_endpoint.IndexEndpoint] - ]]: + def update_index_endpoint( + self, + ) -> typing.Callable[ + [index_endpoint_service.UpdateIndexEndpointRequest], + typing.Union[ + gca_index_endpoint.IndexEndpoint, + typing.Awaitable[gca_index_endpoint.IndexEndpoint], + ], + ]: raise NotImplementedError() @property - def delete_index_endpoint(self) -> typing.Callable[ - [index_endpoint_service.DeleteIndexEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_index_endpoint( + self, + ) -> typing.Callable[ + [index_endpoint_service.DeleteIndexEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def deploy_index(self) -> typing.Callable[ - [index_endpoint_service.DeployIndexRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def deploy_index( + self, + ) -> typing.Callable[ + [index_endpoint_service.DeployIndexRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def undeploy_index(self) -> typing.Callable[ - [index_endpoint_service.UndeployIndexRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def undeploy_index( + self, + ) -> typing.Callable[ + [index_endpoint_service.UndeployIndexRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'IndexEndpointServiceTransport', -) +__all__ = ("IndexEndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py index 274c8cdc6f..a41e483a61 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -47,21 +47,24 @@ class IndexEndpointServiceGrpcTransport(IndexEndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -173,13 +176,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -212,7 +217,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -230,17 +235,17 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_index_endpoint(self) -> Callable[ - [index_endpoint_service.CreateIndexEndpointRequest], - operations.Operation]: + def create_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.CreateIndexEndpointRequest], operations.Operation + ]: r"""Return a callable for the create index endpoint method over gRPC. Creates an IndexEndpoint. @@ -255,18 +260,20 @@ def create_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_index_endpoint' not in self._stubs: - self._stubs['create_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/CreateIndexEndpoint', + if "create_index_endpoint" not in self._stubs: + self._stubs["create_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/CreateIndexEndpoint", request_serializer=index_endpoint_service.CreateIndexEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_index_endpoint'] + return self._stubs["create_index_endpoint"] @property - def get_index_endpoint(self) -> Callable[ - [index_endpoint_service.GetIndexEndpointRequest], - index_endpoint.IndexEndpoint]: + def get_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.GetIndexEndpointRequest], index_endpoint.IndexEndpoint + ]: r"""Return a callable for the get index endpoint method over gRPC. Gets an IndexEndpoint. @@ -281,18 +288,21 @@ def get_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_index_endpoint' not in self._stubs: - self._stubs['get_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/GetIndexEndpoint', + if "get_index_endpoint" not in self._stubs: + self._stubs["get_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/GetIndexEndpoint", request_serializer=index_endpoint_service.GetIndexEndpointRequest.serialize, response_deserializer=index_endpoint.IndexEndpoint.deserialize, ) - return self._stubs['get_index_endpoint'] + return self._stubs["get_index_endpoint"] @property - def list_index_endpoints(self) -> Callable[ - [index_endpoint_service.ListIndexEndpointsRequest], - index_endpoint_service.ListIndexEndpointsResponse]: + def list_index_endpoints( + self, + ) -> Callable[ + [index_endpoint_service.ListIndexEndpointsRequest], + index_endpoint_service.ListIndexEndpointsResponse, + ]: r"""Return a callable for the list index endpoints method over gRPC. Lists IndexEndpoints in a Location. @@ -307,18 +317,21 @@ def list_index_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_index_endpoints' not in self._stubs: - self._stubs['list_index_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/ListIndexEndpoints', + if "list_index_endpoints" not in self._stubs: + self._stubs["list_index_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/ListIndexEndpoints", request_serializer=index_endpoint_service.ListIndexEndpointsRequest.serialize, response_deserializer=index_endpoint_service.ListIndexEndpointsResponse.deserialize, ) - return self._stubs['list_index_endpoints'] + return self._stubs["list_index_endpoints"] @property - def update_index_endpoint(self) -> Callable[ - [index_endpoint_service.UpdateIndexEndpointRequest], - gca_index_endpoint.IndexEndpoint]: + def update_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.UpdateIndexEndpointRequest], + gca_index_endpoint.IndexEndpoint, + ]: r"""Return a callable for the update index endpoint method over gRPC. Updates an IndexEndpoint. @@ -333,18 +346,20 @@ def update_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_index_endpoint' not in self._stubs: - self._stubs['update_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UpdateIndexEndpoint', + if "update_index_endpoint" not in self._stubs: + self._stubs["update_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/UpdateIndexEndpoint", request_serializer=index_endpoint_service.UpdateIndexEndpointRequest.serialize, response_deserializer=gca_index_endpoint.IndexEndpoint.deserialize, ) - return self._stubs['update_index_endpoint'] + return self._stubs["update_index_endpoint"] @property - def delete_index_endpoint(self) -> Callable[ - [index_endpoint_service.DeleteIndexEndpointRequest], - operations.Operation]: + def delete_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.DeleteIndexEndpointRequest], operations.Operation + ]: r"""Return a callable for the delete index endpoint method over gRPC. Deletes an IndexEndpoint. @@ -359,18 +374,18 @@ def delete_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_index_endpoint' not in self._stubs: - self._stubs['delete_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeleteIndexEndpoint', + if "delete_index_endpoint" not in self._stubs: + self._stubs["delete_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeleteIndexEndpoint", request_serializer=index_endpoint_service.DeleteIndexEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_index_endpoint'] + return self._stubs["delete_index_endpoint"] @property - def deploy_index(self) -> Callable[ - [index_endpoint_service.DeployIndexRequest], - operations.Operation]: + def deploy_index( + self, + ) -> Callable[[index_endpoint_service.DeployIndexRequest], operations.Operation]: r"""Return a callable for the deploy index method over gRPC. Deploys an Index into this IndexEndpoint, creating a @@ -387,18 +402,18 @@ def deploy_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'deploy_index' not in self._stubs: - self._stubs['deploy_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeployIndex', + if "deploy_index" not in self._stubs: + self._stubs["deploy_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeployIndex", request_serializer=index_endpoint_service.DeployIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['deploy_index'] + return self._stubs["deploy_index"] @property - def undeploy_index(self) -> Callable[ - [index_endpoint_service.UndeployIndexRequest], - operations.Operation]: + def undeploy_index( + self, + ) -> Callable[[index_endpoint_service.UndeployIndexRequest], operations.Operation]: r"""Return a callable for the undeploy index method over gRPC. Undeploys an Index from an IndexEndpoint, removing a @@ -415,15 +430,13 @@ def undeploy_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'undeploy_index' not in self._stubs: - self._stubs['undeploy_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UndeployIndex', + if "undeploy_index" not in self._stubs: + self._stubs["undeploy_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/UndeployIndex", request_serializer=index_endpoint_service.UndeployIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['undeploy_index'] + return self._stubs["undeploy_index"] -__all__ = ( - 'IndexEndpointServiceGrpcTransport', -) +__all__ = ("IndexEndpointServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py index 3b2c0fb5ce..a34337a84f 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import index_endpoint @@ -54,13 +54,15 @@ class IndexEndpointServiceGrpcAsyncIOTransport(IndexEndpointServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -89,22 +91,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -243,9 +247,12 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_index_endpoint(self) -> Callable[ - [index_endpoint_service.CreateIndexEndpointRequest], - Awaitable[operations.Operation]]: + def create_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.CreateIndexEndpointRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the create index endpoint method over gRPC. Creates an IndexEndpoint. @@ -260,18 +267,21 @@ def create_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_index_endpoint' not in self._stubs: - self._stubs['create_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/CreateIndexEndpoint', + if "create_index_endpoint" not in self._stubs: + self._stubs["create_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/CreateIndexEndpoint", request_serializer=index_endpoint_service.CreateIndexEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_index_endpoint'] + return self._stubs["create_index_endpoint"] @property - def get_index_endpoint(self) -> Callable[ - [index_endpoint_service.GetIndexEndpointRequest], - Awaitable[index_endpoint.IndexEndpoint]]: + def get_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.GetIndexEndpointRequest], + Awaitable[index_endpoint.IndexEndpoint], + ]: r"""Return a callable for the get index endpoint method over gRPC. Gets an IndexEndpoint. @@ -286,18 +296,21 @@ def get_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_index_endpoint' not in self._stubs: - self._stubs['get_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/GetIndexEndpoint', + if "get_index_endpoint" not in self._stubs: + self._stubs["get_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/GetIndexEndpoint", request_serializer=index_endpoint_service.GetIndexEndpointRequest.serialize, response_deserializer=index_endpoint.IndexEndpoint.deserialize, ) - return self._stubs['get_index_endpoint'] + return self._stubs["get_index_endpoint"] @property - def list_index_endpoints(self) -> Callable[ - [index_endpoint_service.ListIndexEndpointsRequest], - Awaitable[index_endpoint_service.ListIndexEndpointsResponse]]: + def list_index_endpoints( + self, + ) -> Callable[ + [index_endpoint_service.ListIndexEndpointsRequest], + Awaitable[index_endpoint_service.ListIndexEndpointsResponse], + ]: r"""Return a callable for the list index endpoints method over gRPC. Lists IndexEndpoints in a Location. @@ -312,18 +325,21 @@ def list_index_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_index_endpoints' not in self._stubs: - self._stubs['list_index_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/ListIndexEndpoints', + if "list_index_endpoints" not in self._stubs: + self._stubs["list_index_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/ListIndexEndpoints", request_serializer=index_endpoint_service.ListIndexEndpointsRequest.serialize, response_deserializer=index_endpoint_service.ListIndexEndpointsResponse.deserialize, ) - return self._stubs['list_index_endpoints'] + return self._stubs["list_index_endpoints"] @property - def update_index_endpoint(self) -> Callable[ - [index_endpoint_service.UpdateIndexEndpointRequest], - Awaitable[gca_index_endpoint.IndexEndpoint]]: + def update_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.UpdateIndexEndpointRequest], + Awaitable[gca_index_endpoint.IndexEndpoint], + ]: r"""Return a callable for the update index endpoint method over gRPC. Updates an IndexEndpoint. @@ -338,18 +354,21 @@ def update_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_index_endpoint' not in self._stubs: - self._stubs['update_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UpdateIndexEndpoint', + if "update_index_endpoint" not in self._stubs: + self._stubs["update_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/UpdateIndexEndpoint", request_serializer=index_endpoint_service.UpdateIndexEndpointRequest.serialize, response_deserializer=gca_index_endpoint.IndexEndpoint.deserialize, ) - return self._stubs['update_index_endpoint'] + return self._stubs["update_index_endpoint"] @property - def delete_index_endpoint(self) -> Callable[ - [index_endpoint_service.DeleteIndexEndpointRequest], - Awaitable[operations.Operation]]: + def delete_index_endpoint( + self, + ) -> Callable[ + [index_endpoint_service.DeleteIndexEndpointRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete index endpoint method over gRPC. Deletes an IndexEndpoint. @@ -364,18 +383,20 @@ def delete_index_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_index_endpoint' not in self._stubs: - self._stubs['delete_index_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeleteIndexEndpoint', + if "delete_index_endpoint" not in self._stubs: + self._stubs["delete_index_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeleteIndexEndpoint", request_serializer=index_endpoint_service.DeleteIndexEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_index_endpoint'] + return self._stubs["delete_index_endpoint"] @property - def deploy_index(self) -> Callable[ - [index_endpoint_service.DeployIndexRequest], - Awaitable[operations.Operation]]: + def deploy_index( + self, + ) -> Callable[ + [index_endpoint_service.DeployIndexRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the deploy index method over gRPC. Deploys an Index into this IndexEndpoint, creating a @@ -392,18 +413,20 @@ def deploy_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'deploy_index' not in self._stubs: - self._stubs['deploy_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeployIndex', + if "deploy_index" not in self._stubs: + self._stubs["deploy_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/DeployIndex", request_serializer=index_endpoint_service.DeployIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['deploy_index'] + return self._stubs["deploy_index"] @property - def undeploy_index(self) -> Callable[ - [index_endpoint_service.UndeployIndexRequest], - Awaitable[operations.Operation]]: + def undeploy_index( + self, + ) -> Callable[ + [index_endpoint_service.UndeployIndexRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the undeploy index method over gRPC. Undeploys an Index from an IndexEndpoint, removing a @@ -420,15 +443,13 @@ def undeploy_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'undeploy_index' not in self._stubs: - self._stubs['undeploy_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexEndpointService/UndeployIndex', + if "undeploy_index" not in self._stubs: + self._stubs["undeploy_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexEndpointService/UndeployIndex", request_serializer=index_endpoint_service.UndeployIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['undeploy_index'] + return self._stubs["undeploy_index"] -__all__ = ( - 'IndexEndpointServiceGrpcAsyncIOTransport', -) +__all__ = ("IndexEndpointServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py index 5b6569d841..bf9cebd517 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import IndexServiceAsyncClient __all__ = ( - 'IndexServiceClient', - 'IndexServiceAsyncClient', + "IndexServiceClient", + "IndexServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py index 49fc00f568..346bd1bc1e 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -59,22 +59,34 @@ class IndexServiceAsyncClient: index_path = staticmethod(IndexServiceClient.index_path) parse_index_path = staticmethod(IndexServiceClient.parse_index_path) index_endpoint_path = staticmethod(IndexServiceClient.index_endpoint_path) - parse_index_endpoint_path = staticmethod(IndexServiceClient.parse_index_endpoint_path) + parse_index_endpoint_path = staticmethod( + IndexServiceClient.parse_index_endpoint_path + ) - common_billing_account_path = staticmethod(IndexServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(IndexServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + IndexServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + IndexServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(IndexServiceClient.common_folder_path) parse_common_folder_path = staticmethod(IndexServiceClient.parse_common_folder_path) common_organization_path = staticmethod(IndexServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(IndexServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + IndexServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(IndexServiceClient.common_project_path) - parse_common_project_path = staticmethod(IndexServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + IndexServiceClient.parse_common_project_path + ) common_location_path = staticmethod(IndexServiceClient.common_location_path) - parse_common_location_path = staticmethod(IndexServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + IndexServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -117,14 +129,18 @@ def transport(self) -> IndexServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(IndexServiceClient).get_transport_class, type(IndexServiceClient)) + get_transport_class = functools.partial( + type(IndexServiceClient).get_transport_class, type(IndexServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, IndexServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, IndexServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the index service client. Args: @@ -163,18 +179,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_index(self, - request: index_service.CreateIndexRequest = None, - *, - parent: str = None, - index: gca_index.Index = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_index( + self, + request: index_service.CreateIndexRequest = None, + *, + parent: str = None, + index: gca_index.Index = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an Index. Args: @@ -215,8 +231,10 @@ async def create_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_service.CreateIndexRequest(request) @@ -239,18 +257,11 @@ async def create_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -263,14 +274,15 @@ async def create_index(self, # Done; return the response. return response - async def get_index(self, - request: index_service.GetIndexRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> index.Index: + async def get_index( + self, + request: index_service.GetIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index.Index: r"""Gets an Index. Args: @@ -304,8 +316,10 @@ async def get_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_service.GetIndexRequest(request) @@ -326,30 +340,24 @@ async def get_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_indexes(self, - request: index_service.ListIndexesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListIndexesAsyncPager: + async def list_indexes( + self, + request: index_service.ListIndexesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexesAsyncPager: r"""Lists Indexes in a Location. Args: @@ -385,8 +393,10 @@ async def list_indexes(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_service.ListIndexesRequest(request) @@ -407,40 +417,31 @@ async def list_indexes(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListIndexesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_index(self, - request: index_service.UpdateIndexRequest = None, - *, - index: gca_index.Index = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_index( + self, + request: index_service.UpdateIndexRequest = None, + *, + index: gca_index.Index = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates an Index. Args: @@ -483,8 +484,10 @@ async def update_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_service.UpdateIndexRequest(request) @@ -507,18 +510,13 @@ async def update_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index.name', request.index.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index.name", request.index.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -531,14 +529,15 @@ async def update_index(self, # Done; return the response. return response - async def delete_index(self, - request: index_service.DeleteIndexRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_index( + self, + request: index_service.DeleteIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an Index. An Index can only be deleted when all its ``DeployedIndexes`` had been undeployed. @@ -586,8 +585,10 @@ async def delete_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = index_service.DeleteIndexRequest(request) @@ -608,18 +609,11 @@ async def delete_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -633,21 +627,14 @@ async def delete_index(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'IndexServiceAsyncClient', -) +__all__ = ("IndexServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py index 133cf63a94..b90771f405 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,13 +57,12 @@ class IndexServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[IndexServiceTransport]] - _transport_registry['grpc'] = IndexServiceGrpcTransport - _transport_registry['grpc_asyncio'] = IndexServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = IndexServiceGrpcTransport + _transport_registry["grpc_asyncio"] = IndexServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[IndexServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[IndexServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +115,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -151,9 +150,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: IndexServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,88 +166,104 @@ def transport(self) -> IndexServiceTransport: return self._transport @staticmethod - def index_path(project: str,location: str,index: str,) -> str: + def index_path(project: str, location: str, index: str,) -> str: """Return a fully-qualified index string.""" - return "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + return "projects/{project}/locations/{location}/indexes/{index}".format( + project=project, location=location, index=index, + ) @staticmethod - def parse_index_path(path: str) -> Dict[str,str]: + def parse_index_path(path: str) -> Dict[str, str]: """Parse a index path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexes/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/indexes/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def index_endpoint_path(project: str,location: str,index_endpoint: str,) -> str: + def index_endpoint_path(project: str, location: str, index_endpoint: str,) -> str: """Return a fully-qualified index_endpoint string.""" - return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( + project=project, location=location, index_endpoint=index_endpoint, + ) @staticmethod - def parse_index_endpoint_path(path: str) -> Dict[str,str]: + def parse_index_endpoint_path(path: str) -> Dict[str, str]: """Parse a index_endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, IndexServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, IndexServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the index service client. Args: @@ -293,7 +307,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -303,7 +319,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -315,7 +333,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -327,8 +347,10 @@ def __init__(self, *, if isinstance(transport, IndexServiceTransport): # transport is a IndexServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -347,15 +369,16 @@ def __init__(self, *, client_info=client_info, ) - def create_index(self, - request: index_service.CreateIndexRequest = None, - *, - parent: str = None, - index: gca_index.Index = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_index( + self, + request: index_service.CreateIndexRequest = None, + *, + parent: str = None, + index: gca_index.Index = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates an Index. Args: @@ -396,8 +419,10 @@ def create_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_service.CreateIndexRequest. @@ -421,18 +446,11 @@ def create_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -445,14 +463,15 @@ def create_index(self, # Done; return the response. return response - def get_index(self, - request: index_service.GetIndexRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> index.Index: + def get_index( + self, + request: index_service.GetIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> index.Index: r"""Gets an Index. Args: @@ -486,8 +505,10 @@ def get_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_service.GetIndexRequest. @@ -509,30 +530,24 @@ def get_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_indexes(self, - request: index_service.ListIndexesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListIndexesPager: + def list_indexes( + self, + request: index_service.ListIndexesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListIndexesPager: r"""Lists Indexes in a Location. Args: @@ -568,8 +583,10 @@ def list_indexes(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_service.ListIndexesRequest. @@ -591,40 +608,31 @@ def list_indexes(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListIndexesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_index(self, - request: index_service.UpdateIndexRequest = None, - *, - index: gca_index.Index = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def update_index( + self, + request: index_service.UpdateIndexRequest = None, + *, + index: gca_index.Index = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Updates an Index. Args: @@ -667,8 +675,10 @@ def update_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_service.UpdateIndexRequest. @@ -692,18 +702,13 @@ def update_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('index.name', request.index.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("index.name", request.index.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -716,14 +721,15 @@ def update_index(self, # Done; return the response. return response - def delete_index(self, - request: index_service.DeleteIndexRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_index( + self, + request: index_service.DeleteIndexRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes an Index. An Index can only be deleted when all its ``DeployedIndexes`` had been undeployed. @@ -771,8 +777,10 @@ def delete_index(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a index_service.DeleteIndexRequest. @@ -794,18 +802,11 @@ def delete_index(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -819,21 +820,14 @@ def delete_index(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'IndexServiceClient', -) +__all__ = ("IndexServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py index dea7e37830..18b3cea2f7 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import index from google.cloud.aiplatform_v1beta1.types import index_service @@ -38,12 +47,15 @@ class ListIndexesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., index_service.ListIndexesResponse], - request: index_service.ListIndexesRequest, - response: index_service.ListIndexesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., index_service.ListIndexesResponse], + request: index_service.ListIndexesRequest, + response: index_service.ListIndexesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +89,7 @@ def __iter__(self) -> Iterable[index.Index]: yield from page.indexes def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListIndexesAsyncPager: @@ -97,12 +109,15 @@ class ListIndexesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[index_service.ListIndexesResponse]], - request: index_service.ListIndexesRequest, - response: index_service.ListIndexesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[index_service.ListIndexesResponse]], + request: index_service.ListIndexesRequest, + response: index_service.ListIndexesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -140,4 +155,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py index 7bb2e2abad..f9345ef29c 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/__init__.py @@ -25,11 +25,11 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[IndexServiceTransport]] -_transport_registry['grpc'] = IndexServiceGrpcTransport -_transport_registry['grpc_asyncio'] = IndexServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = IndexServiceGrpcTransport +_transport_registry["grpc_asyncio"] = IndexServiceGrpcAsyncIOTransport __all__ = ( - 'IndexServiceTransport', - 'IndexServiceGrpcTransport', - 'IndexServiceGrpcAsyncIOTransport', + "IndexServiceTransport", + "IndexServiceGrpcTransport", + "IndexServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py index fd218d13dd..9c5e4d4538 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class IndexServiceTransport(abc.ABC): """Abstract transport class for IndexService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -79,8 +79,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -89,17 +89,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -108,31 +110,20 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_index: gapic_v1.method.wrap_method( - self.create_index, - default_timeout=None, - client_info=client_info, + self.create_index, default_timeout=None, client_info=client_info, ), self.get_index: gapic_v1.method.wrap_method( - self.get_index, - default_timeout=None, - client_info=client_info, + self.get_index, default_timeout=None, client_info=client_info, ), self.list_indexes: gapic_v1.method.wrap_method( - self.list_indexes, - default_timeout=None, - client_info=client_info, + self.list_indexes, default_timeout=None, client_info=client_info, ), self.update_index: gapic_v1.method.wrap_method( - self.update_index, - default_timeout=None, - client_info=client_info, + self.update_index, default_timeout=None, client_info=client_info, ), self.delete_index: gapic_v1.method.wrap_method( - self.delete_index, - default_timeout=None, - client_info=client_info, + self.delete_index, default_timeout=None, client_info=client_info, ), - } @property @@ -141,51 +132,52 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_index(self) -> typing.Callable[ - [index_service.CreateIndexRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_index( + self, + ) -> typing.Callable[ + [index_service.CreateIndexRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_index(self) -> typing.Callable[ - [index_service.GetIndexRequest], - typing.Union[ - index.Index, - typing.Awaitable[index.Index] - ]]: + def get_index( + self, + ) -> typing.Callable[ + [index_service.GetIndexRequest], + typing.Union[index.Index, typing.Awaitable[index.Index]], + ]: raise NotImplementedError() @property - def list_indexes(self) -> typing.Callable[ - [index_service.ListIndexesRequest], - typing.Union[ - index_service.ListIndexesResponse, - typing.Awaitable[index_service.ListIndexesResponse] - ]]: + def list_indexes( + self, + ) -> typing.Callable[ + [index_service.ListIndexesRequest], + typing.Union[ + index_service.ListIndexesResponse, + typing.Awaitable[index_service.ListIndexesResponse], + ], + ]: raise NotImplementedError() @property - def update_index(self) -> typing.Callable[ - [index_service.UpdateIndexRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_index( + self, + ) -> typing.Callable[ + [index_service.UpdateIndexRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def delete_index(self) -> typing.Callable[ - [index_service.DeleteIndexRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_index( + self, + ) -> typing.Callable[ + [index_service.DeleteIndexRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'IndexServiceTransport', -) +__all__ = ("IndexServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py index 783ab5733f..85f15724f2 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -47,21 +47,24 @@ class IndexServiceGrpcTransport(IndexServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -173,13 +176,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -212,7 +217,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -230,17 +235,15 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Sanity check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def create_index(self) -> Callable[ - [index_service.CreateIndexRequest], - operations.Operation]: + def create_index( + self, + ) -> Callable[[index_service.CreateIndexRequest], operations.Operation]: r"""Return a callable for the create index method over gRPC. Creates an Index. @@ -255,18 +258,16 @@ def create_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_index' not in self._stubs: - self._stubs['create_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/CreateIndex', + if "create_index" not in self._stubs: + self._stubs["create_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/CreateIndex", request_serializer=index_service.CreateIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_index'] + return self._stubs["create_index"] @property - def get_index(self) -> Callable[ - [index_service.GetIndexRequest], - index.Index]: + def get_index(self) -> Callable[[index_service.GetIndexRequest], index.Index]: r"""Return a callable for the get index method over gRPC. Gets an Index. @@ -281,18 +282,20 @@ def get_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_index' not in self._stubs: - self._stubs['get_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/GetIndex', + if "get_index" not in self._stubs: + self._stubs["get_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/GetIndex", request_serializer=index_service.GetIndexRequest.serialize, response_deserializer=index.Index.deserialize, ) - return self._stubs['get_index'] + return self._stubs["get_index"] @property - def list_indexes(self) -> Callable[ - [index_service.ListIndexesRequest], - index_service.ListIndexesResponse]: + def list_indexes( + self, + ) -> Callable[ + [index_service.ListIndexesRequest], index_service.ListIndexesResponse + ]: r"""Return a callable for the list indexes method over gRPC. Lists Indexes in a Location. @@ -307,18 +310,18 @@ def list_indexes(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_indexes' not in self._stubs: - self._stubs['list_indexes'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/ListIndexes', + if "list_indexes" not in self._stubs: + self._stubs["list_indexes"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/ListIndexes", request_serializer=index_service.ListIndexesRequest.serialize, response_deserializer=index_service.ListIndexesResponse.deserialize, ) - return self._stubs['list_indexes'] + return self._stubs["list_indexes"] @property - def update_index(self) -> Callable[ - [index_service.UpdateIndexRequest], - operations.Operation]: + def update_index( + self, + ) -> Callable[[index_service.UpdateIndexRequest], operations.Operation]: r"""Return a callable for the update index method over gRPC. Updates an Index. @@ -333,18 +336,18 @@ def update_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_index' not in self._stubs: - self._stubs['update_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/UpdateIndex', + if "update_index" not in self._stubs: + self._stubs["update_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/UpdateIndex", request_serializer=index_service.UpdateIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_index'] + return self._stubs["update_index"] @property - def delete_index(self) -> Callable[ - [index_service.DeleteIndexRequest], - operations.Operation]: + def delete_index( + self, + ) -> Callable[[index_service.DeleteIndexRequest], operations.Operation]: r"""Return a callable for the delete index method over gRPC. Deletes an Index. An Index can only be deleted when all its @@ -361,15 +364,13 @@ def delete_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_index' not in self._stubs: - self._stubs['delete_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/DeleteIndex', + if "delete_index" not in self._stubs: + self._stubs["delete_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/DeleteIndex", request_serializer=index_service.DeleteIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_index'] + return self._stubs["delete_index"] -__all__ = ( - 'IndexServiceGrpcTransport', -) +__all__ = ("IndexServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py index e0287ff613..2eb6cbc633 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import index @@ -54,13 +54,15 @@ class IndexServiceGrpcAsyncIOTransport(IndexServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -89,22 +91,24 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -243,9 +247,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def create_index(self) -> Callable[ - [index_service.CreateIndexRequest], - Awaitable[operations.Operation]]: + def create_index( + self, + ) -> Callable[[index_service.CreateIndexRequest], Awaitable[operations.Operation]]: r"""Return a callable for the create index method over gRPC. Creates an Index. @@ -260,18 +264,18 @@ def create_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_index' not in self._stubs: - self._stubs['create_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/CreateIndex', + if "create_index" not in self._stubs: + self._stubs["create_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/CreateIndex", request_serializer=index_service.CreateIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_index'] + return self._stubs["create_index"] @property - def get_index(self) -> Callable[ - [index_service.GetIndexRequest], - Awaitable[index.Index]]: + def get_index( + self, + ) -> Callable[[index_service.GetIndexRequest], Awaitable[index.Index]]: r"""Return a callable for the get index method over gRPC. Gets an Index. @@ -286,18 +290,20 @@ def get_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_index' not in self._stubs: - self._stubs['get_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/GetIndex', + if "get_index" not in self._stubs: + self._stubs["get_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/GetIndex", request_serializer=index_service.GetIndexRequest.serialize, response_deserializer=index.Index.deserialize, ) - return self._stubs['get_index'] + return self._stubs["get_index"] @property - def list_indexes(self) -> Callable[ - [index_service.ListIndexesRequest], - Awaitable[index_service.ListIndexesResponse]]: + def list_indexes( + self, + ) -> Callable[ + [index_service.ListIndexesRequest], Awaitable[index_service.ListIndexesResponse] + ]: r"""Return a callable for the list indexes method over gRPC. Lists Indexes in a Location. @@ -312,18 +318,18 @@ def list_indexes(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_indexes' not in self._stubs: - self._stubs['list_indexes'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/ListIndexes', + if "list_indexes" not in self._stubs: + self._stubs["list_indexes"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/ListIndexes", request_serializer=index_service.ListIndexesRequest.serialize, response_deserializer=index_service.ListIndexesResponse.deserialize, ) - return self._stubs['list_indexes'] + return self._stubs["list_indexes"] @property - def update_index(self) -> Callable[ - [index_service.UpdateIndexRequest], - Awaitable[operations.Operation]]: + def update_index( + self, + ) -> Callable[[index_service.UpdateIndexRequest], Awaitable[operations.Operation]]: r"""Return a callable for the update index method over gRPC. Updates an Index. @@ -338,18 +344,18 @@ def update_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_index' not in self._stubs: - self._stubs['update_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/UpdateIndex', + if "update_index" not in self._stubs: + self._stubs["update_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/UpdateIndex", request_serializer=index_service.UpdateIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_index'] + return self._stubs["update_index"] @property - def delete_index(self) -> Callable[ - [index_service.DeleteIndexRequest], - Awaitable[operations.Operation]]: + def delete_index( + self, + ) -> Callable[[index_service.DeleteIndexRequest], Awaitable[operations.Operation]]: r"""Return a callable for the delete index method over gRPC. Deletes an Index. An Index can only be deleted when all its @@ -366,15 +372,13 @@ def delete_index(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_index' not in self._stubs: - self._stubs['delete_index'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.IndexService/DeleteIndex', + if "delete_index" not in self._stubs: + self._stubs["delete_index"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.IndexService/DeleteIndex", request_serializer=index_service.DeleteIndexRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_index'] + return self._stubs["delete_index"] -__all__ = ( - 'IndexServiceGrpcAsyncIOTransport', -) +__all__ = ("IndexServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index de6c880e58..6f649532af 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -23,36 +23,44 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job -from google.cloud.aiplatform_v1beta1.types import model_deployment_monitoring_job as gca_model_deployment_monitoring_job +from google.cloud.aiplatform_v1beta1.types import ( + model_deployment_monitoring_job as gca_model_deployment_monitoring_job, +) from google.cloud.aiplatform_v1beta1.types import model_monitoring from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import study @@ -76,13 +84,12 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry['grpc'] = JobServiceGrpcTransport - _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[JobServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -133,7 +140,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -168,9 +175,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: JobServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -185,165 +191,230 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: + def batch_prediction_job_path( + project: str, location: str, batch_prediction_job: str, + ) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, + location=location, + batch_prediction_job=batch_prediction_job, + ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str,location: str,custom_job: str,) -> str: + def custom_job_path(project: str, location: str, custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str,str]: + def parse_custom_job_path(path: str) -> Dict[str, str]: """Parse a custom_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str,str]: + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: + def hyperparameter_tuning_job_path( + project: str, location: str, hyperparameter_tuning_job: str, + ) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_deployment_monitoring_job_path(project: str,location: str,model_deployment_monitoring_job: str,) -> str: + def model_deployment_monitoring_job_path( + project: str, location: str, model_deployment_monitoring_job: str, + ) -> str: """Return a fully-qualified model_deployment_monitoring_job string.""" - return "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format(project=project, location=location, model_deployment_monitoring_job=model_deployment_monitoring_job, ) + return "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format( + project=project, + location=location, + model_deployment_monitoring_job=model_deployment_monitoring_job, + ) @staticmethod - def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str,str]: + def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str, str]: """Parse a model_deployment_monitoring_job path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/modelDeploymentMonitoringJobs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/modelDeploymentMonitoringJobs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def trial_path(project: str,location: str,study: str,trial: str,) -> str: + def trial_path(project: str, location: str, study: str, trial: str,) -> str: """Return a fully-qualified trial string.""" - return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format(project=project, location=location, study=study, trial=trial, ) + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) @staticmethod - def parse_trial_path(path: str) -> Dict[str,str]: + def parse_trial_path(path: str) -> Dict[str, str]: """Parse a trial path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -387,7 +458,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -397,7 +470,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -409,7 +484,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -421,8 +498,10 @@ def __init__(self, *, if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -441,15 +520,16 @@ def __init__(self, *, client_info=client_info, ) - def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -494,8 +574,10 @@ def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -519,30 +601,24 @@ def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -580,8 +656,10 @@ def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -603,30 +681,24 @@ def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -662,8 +734,10 @@ def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -685,39 +759,30 @@ def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a CustomJob. Args: @@ -763,8 +828,10 @@ def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -786,18 +853,11 @@ def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -810,14 +870,15 @@ def delete_custom_job(self, # Done; return the response. return response - def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -855,8 +916,10 @@ def cancel_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -878,28 +941,24 @@ def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -939,8 +998,10 @@ def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -964,30 +1025,24 @@ def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -1020,8 +1075,10 @@ def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -1043,30 +1100,24 @@ def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1101,8 +1152,10 @@ def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1124,39 +1177,30 @@ def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1202,8 +1246,10 @@ def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1225,18 +1271,11 @@ def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1249,14 +1288,15 @@ def delete_data_labeling_job(self, # Done; return the response. return response - def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1283,8 +1323,10 @@ def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1306,28 +1348,24 @@ def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1369,8 +1407,10 @@ def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1389,35 +1429,31 @@ def create_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1452,8 +1488,10 @@ def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1470,35 +1508,31 @@ def get_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1534,8 +1568,10 @@ def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1552,44 +1588,37 @@ def list_hyperparameter_tuning_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1635,8 +1664,10 @@ def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1653,23 +1684,18 @@ def delete_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1682,14 +1708,15 @@ def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1729,8 +1756,10 @@ def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1747,33 +1776,31 @@ def cancel_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1818,8 +1845,10 @@ def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1838,35 +1867,31 @@ def create_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1903,8 +1928,10 @@ def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1926,30 +1953,24 @@ def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1985,8 +2006,10 @@ def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -2003,44 +2026,37 @@ def list_batch_prediction_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2087,8 +2103,10 @@ def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2105,23 +2123,18 @@ def delete_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -2134,14 +2147,15 @@ def delete_batch_prediction_job(self, # Done; return the response. return response - def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -2179,8 +2193,10 @@ def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2197,33 +2213,31 @@ def cancel_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def create_model_deployment_monitoring_job(self, - request: job_service.CreateModelDeploymentMonitoringJobRequest = None, - *, - parent: str = None, - model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_model_deployment_monitoring_job( + self, + request: job_service.CreateModelDeploymentMonitoringJobRequest = None, + *, + parent: str = None, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob: r"""Creates a ModelDeploymentMonitoringJob. It will run periodically on a configured interval. @@ -2267,14 +2281,18 @@ def create_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.CreateModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.CreateModelDeploymentMonitoringJobRequest + ): request = job_service.CreateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2283,40 +2301,38 @@ def create_model_deployment_monitoring_job(self, if parent is not None: request.parent = parent if model_deployment_monitoring_job is not None: - request.model_deployment_monitoring_job = model_deployment_monitoring_job + request.model_deployment_monitoring_job = ( + model_deployment_monitoring_job + ) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def search_model_deployment_monitoring_stats_anomalies(self, - request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, - *, - model_deployment_monitoring_job: str = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: + def search_model_deployment_monitoring_stats_anomalies( + self, + request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest = None, + *, + model_deployment_monitoring_job: str = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: r"""Searches Model Monitoring Statistics generated within a given time window. @@ -2360,64 +2376,72 @@ def search_model_deployment_monitoring_stats_anomalies(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest): - request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) + if not isinstance( + request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest + ): + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. if model_deployment_monitoring_job is not None: - request.model_deployment_monitoring_job = model_deployment_monitoring_job + request.model_deployment_monitoring_job = ( + model_deployment_monitoring_job + ) if deployed_model_id is not None: request.deployed_model_id = deployed_model_id # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.search_model_deployment_monitoring_stats_anomalies] + rpc = self._transport._wrapped_methods[ + self._transport.search_model_deployment_monitoring_stats_anomalies + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job', request.model_deployment_monitoring_job), - )), + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "model_deployment_monitoring_job", + request.model_deployment_monitoring_job, + ), + ) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_deployment_monitoring_job(self, - request: job_service.GetModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: + def get_model_deployment_monitoring_job( + self, + request: job_service.GetModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_deployment_monitoring_job.ModelDeploymentMonitoringJob: r"""Gets a ModelDeploymentMonitoringJob. Args: @@ -2453,8 +2477,10 @@ def get_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetModelDeploymentMonitoringJobRequest. @@ -2471,35 +2497,31 @@ def get_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_deployment_monitoring_jobs(self, - request: job_service.ListModelDeploymentMonitoringJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelDeploymentMonitoringJobsPager: + def list_model_deployment_monitoring_jobs( + self, + request: job_service.ListModelDeploymentMonitoringJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelDeploymentMonitoringJobsPager: r"""Lists ModelDeploymentMonitoringJobs in a Location. Args: @@ -2535,14 +2557,18 @@ def list_model_deployment_monitoring_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListModelDeploymentMonitoringJobsRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.ListModelDeploymentMonitoringJobsRequest): + if not isinstance( + request, job_service.ListModelDeploymentMonitoringJobsRequest + ): request = job_service.ListModelDeploymentMonitoringJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2553,45 +2579,38 @@ def list_model_deployment_monitoring_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_deployment_monitoring_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_deployment_monitoring_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelDeploymentMonitoringJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model_deployment_monitoring_job(self, - request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, - *, - model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def update_model_deployment_monitoring_job( + self, + request: job_service.UpdateModelDeploymentMonitoringJobRequest = None, + *, + model_deployment_monitoring_job: gca_model_deployment_monitoring_job.ModelDeploymentMonitoringJob = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Updates a ModelDeploymentMonitoringJob. Args: @@ -2634,43 +2653,51 @@ def update_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.UpdateModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.UpdateModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.UpdateModelDeploymentMonitoringJobRequest + ): request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. if model_deployment_monitoring_job is not None: - request.model_deployment_monitoring_job = model_deployment_monitoring_job + request.model_deployment_monitoring_job = ( + model_deployment_monitoring_job + ) if update_mask is not None: request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.update_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.update_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model_deployment_monitoring_job.name', request.model_deployment_monitoring_job.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "model_deployment_monitoring_job.name", + request.model_deployment_monitoring_job.name, + ), + ) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -2683,14 +2710,15 @@ def update_model_deployment_monitoring_job(self, # Done; return the response. return response - def delete_model_deployment_monitoring_job(self, - request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_model_deployment_monitoring_job( + self, + request: job_service.DeleteModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a ModelDeploymentMonitoringJob. Args: @@ -2736,14 +2764,18 @@ def delete_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.DeleteModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.DeleteModelDeploymentMonitoringJobRequest + ): request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2754,23 +2786,18 @@ def delete_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -2783,14 +2810,15 @@ def delete_model_deployment_monitoring_job(self, # Done; return the response. return response - def pause_model_deployment_monitoring_job(self, - request: job_service.PauseModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def pause_model_deployment_monitoring_job( + self, + request: job_service.PauseModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Pauses a ModelDeploymentMonitoringJob. If the job is running, the server makes a best effort to cancel the job. Will mark ``ModelDeploymentMonitoringJob.state`` @@ -2820,14 +2848,18 @@ def pause_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.PauseModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.PauseModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.PauseModelDeploymentMonitoringJobRequest + ): request = job_service.PauseModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2838,32 +2870,30 @@ def pause_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.pause_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.pause_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - def resume_model_deployment_monitoring_job(self, - request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def resume_model_deployment_monitoring_job( + self, + request: job_service.ResumeModelDeploymentMonitoringJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Resumes a paused ModelDeploymentMonitoringJob. It will start to run from next scheduled time. A deleted ModelDeploymentMonitoringJob can't be resumed. @@ -2892,14 +2922,18 @@ def resume_model_deployment_monitoring_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ResumeModelDeploymentMonitoringJobRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, job_service.ResumeModelDeploymentMonitoringJobRequest): + if not isinstance( + request, job_service.ResumeModelDeploymentMonitoringJobRequest + ): request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2910,40 +2944,30 @@ def resume_model_deployment_monitoring_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.resume_model_deployment_monitoring_job] + rpc = self._transport._wrapped_methods[ + self._transport.resume_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceClient', -) +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index 4cbb5b5aeb..8e5a390f2c 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -2382,14 +2382,15 @@ async def list_metadata_schemas( # Done; return the response. return response - async def query_artifact_lineage_subgraph(self, - request: metadata_service.QueryArtifactLineageSubgraphRequest = None, - *, - artifact: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + async def query_artifact_lineage_subgraph( + self, + request: metadata_service.QueryArtifactLineageSubgraphRequest = None, + *, + artifact: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Retrieves lineage of an Artifact represented through Artifacts and Executions connected by Event edges and returned as a LineageSubgraph. @@ -2431,8 +2432,10 @@ async def query_artifact_lineage_subgraph(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = metadata_service.QueryArtifactLineageSubgraphRequest(request) @@ -2453,18 +2456,11 @@ async def query_artifact_lineage_subgraph(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('artifact', request.artifact), - )), + gapic_v1.routing_header.to_grpc_metadata((("artifact", request.artifact),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index 6983d6e5fd..85f65bfe9c 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -67,13 +67,14 @@ class MetadataServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[MetadataServiceTransport]] - _transport_registry['grpc'] = MetadataServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MetadataServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MetadataServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MetadataServiceTransport]] + _transport_registry["grpc"] = MetadataServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MetadataServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MetadataServiceTransport]: """Return an appropriate transport class. Args: @@ -124,7 +125,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -159,9 +160,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MetadataServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -176,121 +176,172 @@ def transport(self) -> MetadataServiceTransport: return self._transport @staticmethod - def artifact_path(project: str,location: str,metadata_store: str,artifact: str,) -> str: + def artifact_path( + project: str, location: str, metadata_store: str, artifact: str, + ) -> str: """Return a fully-qualified artifact string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(project=project, location=location, metadata_store=metadata_store, artifact=artifact, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) @staticmethod - def parse_artifact_path(path: str) -> Dict[str,str]: + def parse_artifact_path(path: str) -> Dict[str, str]: """Parse a artifact path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/artifacts/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/artifacts/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def context_path(project: str,location: str,metadata_store: str,context: str,) -> str: + def context_path( + project: str, location: str, metadata_store: str, context: str, + ) -> str: """Return a fully-qualified context string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(project=project, location=location, metadata_store=metadata_store, context=context, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) @staticmethod - def parse_context_path(path: str) -> Dict[str,str]: + def parse_context_path(path: str) -> Dict[str, str]: """Parse a context path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def execution_path(project: str,location: str,metadata_store: str,execution: str,) -> str: + def execution_path( + project: str, location: str, metadata_store: str, execution: str, + ) -> str: """Return a fully-qualified execution string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(project=project, location=location, metadata_store=metadata_store, execution=execution, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) @staticmethod - def parse_execution_path(path: str) -> Dict[str,str]: + def parse_execution_path(path: str) -> Dict[str, str]: """Parse a execution path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/executions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/executions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def metadata_schema_path(project: str,location: str,metadata_store: str,metadata_schema: str,) -> str: + def metadata_schema_path( + project: str, location: str, metadata_store: str, metadata_schema: str, + ) -> str: """Return a fully-qualified metadata_schema string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format(project=project, location=location, metadata_store=metadata_store, metadata_schema=metadata_schema, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format( + project=project, + location=location, + metadata_store=metadata_store, + metadata_schema=metadata_schema, + ) @staticmethod - def parse_metadata_schema_path(path: str) -> Dict[str,str]: + def parse_metadata_schema_path(path: str) -> Dict[str, str]: """Parse a metadata_schema path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/metadataSchemas/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/metadataSchemas/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def metadata_store_path(project: str,location: str,metadata_store: str,) -> str: + def metadata_store_path(project: str, location: str, metadata_store: str,) -> str: """Return a fully-qualified metadata_store string.""" - return "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format(project=project, location=location, metadata_store=metadata_store, ) + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format( + project=project, location=location, metadata_store=metadata_store, + ) @staticmethod - def parse_metadata_store_path(path: str) -> Dict[str,str]: + def parse_metadata_store_path(path: str) -> Dict[str, str]: """Parse a metadata_store path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MetadataServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MetadataServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the metadata service client. Args: @@ -334,7 +385,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -344,7 +397,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -356,7 +411,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -368,8 +425,10 @@ def __init__(self, *, if isinstance(transport, MetadataServiceTransport): # transport is a MetadataServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -388,16 +447,17 @@ def __init__(self, *, client_info=client_info, ) - def create_metadata_store(self, - request: metadata_service.CreateMetadataStoreRequest = None, - *, - parent: str = None, - metadata_store: gca_metadata_store.MetadataStore = None, - metadata_store_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_metadata_store( + self, + request: metadata_service.CreateMetadataStoreRequest = None, + *, + parent: str = None, + metadata_store: gca_metadata_store.MetadataStore = None, + metadata_store_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Initializes a MetadataStore, including allocation of resources. @@ -456,8 +516,10 @@ def create_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateMetadataStoreRequest. @@ -483,18 +545,11 @@ def create_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -507,14 +562,15 @@ def create_metadata_store(self, # Done; return the response. return response - def get_metadata_store(self, - request: metadata_service.GetMetadataStoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_store.MetadataStore: + def get_metadata_store( + self, + request: metadata_service.GetMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_store.MetadataStore: r"""Retrieves a specific MetadataStore. Args: @@ -548,8 +604,10 @@ def get_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetMetadataStoreRequest. @@ -571,30 +629,24 @@ def get_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_metadata_stores(self, - request: metadata_service.ListMetadataStoresRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListMetadataStoresPager: + def list_metadata_stores( + self, + request: metadata_service.ListMetadataStoresRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataStoresPager: r"""Lists MetadataStores for a Location. Args: @@ -630,8 +682,10 @@ def list_metadata_stores(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListMetadataStoresRequest. @@ -653,39 +707,30 @@ def list_metadata_stores(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListMetadataStoresPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_metadata_store(self, - request: metadata_service.DeleteMetadataStoreRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_metadata_store( + self, + request: metadata_service.DeleteMetadataStoreRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a single MetadataStore. Args: @@ -731,8 +776,10 @@ def delete_metadata_store(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.DeleteMetadataStoreRequest. @@ -754,18 +801,11 @@ def delete_metadata_store(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -778,16 +818,17 @@ def delete_metadata_store(self, # Done; return the response. return response - def create_artifact(self, - request: metadata_service.CreateArtifactRequest = None, - *, - parent: str = None, - artifact: gca_artifact.Artifact = None, - artifact_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_artifact.Artifact: + def create_artifact( + self, + request: metadata_service.CreateArtifactRequest = None, + *, + parent: str = None, + artifact: gca_artifact.Artifact = None, + artifact_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: r"""Creates an Artifact associated with a MetadataStore. Args: @@ -839,8 +880,10 @@ def create_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateArtifactRequest. @@ -866,30 +909,24 @@ def create_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_artifact(self, - request: metadata_service.GetArtifactRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> artifact.Artifact: + def get_artifact( + self, + request: metadata_service.GetArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> artifact.Artifact: r"""Retrieves a specific Artifact. Args: @@ -920,8 +957,10 @@ def get_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetArtifactRequest. @@ -943,30 +982,24 @@ def get_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_artifacts(self, - request: metadata_service.ListArtifactsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListArtifactsPager: + def list_artifacts( + self, + request: metadata_service.ListArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListArtifactsPager: r"""Lists Artifacts in the MetadataStore. Args: @@ -1002,8 +1035,10 @@ def list_artifacts(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListArtifactsRequest. @@ -1025,40 +1060,31 @@ def list_artifacts(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListArtifactsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_artifact(self, - request: metadata_service.UpdateArtifactRequest = None, - *, - artifact: gca_artifact.Artifact = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_artifact.Artifact: + def update_artifact( + self, + request: metadata_service.UpdateArtifactRequest = None, + *, + artifact: gca_artifact.Artifact = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_artifact.Artifact: r"""Updates a stored Artifact. Args: @@ -1099,8 +1125,10 @@ def update_artifact(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.UpdateArtifactRequest. @@ -1124,32 +1152,28 @@ def update_artifact(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('artifact.name', request.artifact.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("artifact.name", request.artifact.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def create_context(self, - request: metadata_service.CreateContextRequest = None, - *, - parent: str = None, - context: gca_context.Context = None, - context_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_context.Context: + def create_context( + self, + request: metadata_service.CreateContextRequest = None, + *, + parent: str = None, + context: gca_context.Context = None, + context_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: r"""Creates a Context associated with a MetadataStore. Args: @@ -1201,8 +1225,10 @@ def create_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateContextRequest. @@ -1228,30 +1254,24 @@ def create_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_context(self, - request: metadata_service.GetContextRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> context.Context: + def get_context( + self, + request: metadata_service.GetContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> context.Context: r"""Retrieves a specific Context. Args: @@ -1282,8 +1302,10 @@ def get_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetContextRequest. @@ -1305,30 +1327,24 @@ def get_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_contexts(self, - request: metadata_service.ListContextsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListContextsPager: + def list_contexts( + self, + request: metadata_service.ListContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListContextsPager: r"""Lists Contexts on the MetadataStore. Args: @@ -1364,8 +1380,10 @@ def list_contexts(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListContextsRequest. @@ -1387,40 +1405,31 @@ def list_contexts(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListContextsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_context(self, - request: metadata_service.UpdateContextRequest = None, - *, - context: gca_context.Context = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_context.Context: + def update_context( + self, + request: metadata_service.UpdateContextRequest = None, + *, + context: gca_context.Context = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_context.Context: r"""Updates a stored Context. Args: @@ -1460,8 +1469,10 @@ def update_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.UpdateContextRequest. @@ -1485,30 +1496,26 @@ def update_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context.name', request.context.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("context.name", request.context.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_context(self, - request: metadata_service.DeleteContextRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_context( + self, + request: metadata_service.DeleteContextRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a stored Context. Args: @@ -1554,8 +1561,10 @@ def delete_context(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.DeleteContextRequest. @@ -1577,18 +1586,11 @@ def delete_context(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -1601,16 +1603,17 @@ def delete_context(self, # Done; return the response. return response - def add_context_artifacts_and_executions(self, - request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, - *, - context: str = None, - artifacts: Sequence[str] = None, - executions: Sequence[str] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: + def add_context_artifacts_and_executions( + self, + request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, + *, + context: str = None, + artifacts: Sequence[str] = None, + executions: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextArtifactsAndExecutionsResponse: r"""Adds a set of Artifacts and Executions to a Context. If any of the Artifacts or Executions have already been added to a Context, they are simply skipped. @@ -1660,14 +1663,18 @@ def add_context_artifacts_and_executions(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.AddContextArtifactsAndExecutionsRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, metadata_service.AddContextArtifactsAndExecutionsRequest): + if not isinstance( + request, metadata_service.AddContextArtifactsAndExecutionsRequest + ): request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1682,36 +1689,32 @@ def add_context_artifacts_and_executions(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.add_context_artifacts_and_executions] + rpc = self._transport._wrapped_methods[ + self._transport.add_context_artifacts_and_executions + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def add_context_children(self, - request: metadata_service.AddContextChildrenRequest = None, - *, - context: str = None, - child_contexts: Sequence[str] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddContextChildrenResponse: + def add_context_children( + self, + request: metadata_service.AddContextChildrenRequest = None, + *, + context: str = None, + child_contexts: Sequence[str] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddContextChildrenResponse: r"""Adds a set of Contexts as children to a parent Context. If any of the child Contexts have already been added to the parent Context, they are simply skipped. If this call would create a @@ -1755,8 +1758,10 @@ def add_context_children(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.AddContextChildrenRequest. @@ -1780,30 +1785,24 @@ def add_context_children(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def query_context_lineage_subgraph(self, - request: metadata_service.QueryContextLineageSubgraphRequest = None, - *, - context: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + def query_context_lineage_subgraph( + self, + request: metadata_service.QueryContextLineageSubgraphRequest = None, + *, + context: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Retrieves Artifacts and Executions within the specified Context, connected by Event edges and returned as a LineageSubgraph. @@ -1845,8 +1844,10 @@ def query_context_lineage_subgraph(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.QueryContextLineageSubgraphRequest. @@ -1863,37 +1864,33 @@ def query_context_lineage_subgraph(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.query_context_lineage_subgraph] + rpc = self._transport._wrapped_methods[ + self._transport.query_context_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('context', request.context), - )), + gapic_v1.routing_header.to_grpc_metadata((("context", request.context),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def create_execution(self, - request: metadata_service.CreateExecutionRequest = None, - *, - parent: str = None, - execution: gca_execution.Execution = None, - execution_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_execution.Execution: + def create_execution( + self, + request: metadata_service.CreateExecutionRequest = None, + *, + parent: str = None, + execution: gca_execution.Execution = None, + execution_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: r"""Creates an Execution associated with a MetadataStore. Args: @@ -1945,8 +1942,10 @@ def create_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateExecutionRequest. @@ -1972,30 +1971,24 @@ def create_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_execution(self, - request: metadata_service.GetExecutionRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> execution.Execution: + def get_execution( + self, + request: metadata_service.GetExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> execution.Execution: r"""Retrieves a specific Execution. Args: @@ -2026,8 +2019,10 @@ def get_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetExecutionRequest. @@ -2049,30 +2044,24 @@ def get_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_executions(self, - request: metadata_service.ListExecutionsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListExecutionsPager: + def list_executions( + self, + request: metadata_service.ListExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListExecutionsPager: r"""Lists Executions in the MetadataStore. Args: @@ -2108,8 +2097,10 @@ def list_executions(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListExecutionsRequest. @@ -2131,40 +2122,31 @@ def list_executions(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListExecutionsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_execution(self, - request: metadata_service.UpdateExecutionRequest = None, - *, - execution: gca_execution.Execution = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_execution.Execution: + def update_execution( + self, + request: metadata_service.UpdateExecutionRequest = None, + *, + execution: gca_execution.Execution = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_execution.Execution: r"""Updates a stored Execution. Args: @@ -2205,8 +2187,10 @@ def update_execution(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.UpdateExecutionRequest. @@ -2230,31 +2214,27 @@ def update_execution(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution.name', request.execution.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution.name", request.execution.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def add_execution_events(self, - request: metadata_service.AddExecutionEventsRequest = None, - *, - execution: str = None, - events: Sequence[event.Event] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_service.AddExecutionEventsResponse: + def add_execution_events( + self, + request: metadata_service.AddExecutionEventsRequest = None, + *, + execution: str = None, + events: Sequence[event.Event] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_service.AddExecutionEventsResponse: r"""Adds Events for denoting whether each Artifact was an input or output for a given Execution. If any Events already exist between the Execution and any of the @@ -2296,8 +2276,10 @@ def add_execution_events(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.AddExecutionEventsRequest. @@ -2321,30 +2303,26 @@ def add_execution_events(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution', request.execution), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution", request.execution),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def query_execution_inputs_and_outputs(self, - request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, - *, - execution: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + def query_execution_inputs_and_outputs( + self, + request: metadata_service.QueryExecutionInputsAndOutputsRequest = None, + *, + execution: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Obtains the set of input and output Artifacts for this Execution, in the form of LineageSubgraph that also contains the Execution and connecting Events. @@ -2382,14 +2360,18 @@ def query_execution_inputs_and_outputs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.QueryExecutionInputsAndOutputsRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, metadata_service.QueryExecutionInputsAndOutputsRequest): + if not isinstance( + request, metadata_service.QueryExecutionInputsAndOutputsRequest + ): request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2400,37 +2382,35 @@ def query_execution_inputs_and_outputs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.query_execution_inputs_and_outputs] + rpc = self._transport._wrapped_methods[ + self._transport.query_execution_inputs_and_outputs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('execution', request.execution), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("execution", request.execution),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def create_metadata_schema(self, - request: metadata_service.CreateMetadataSchemaRequest = None, - *, - parent: str = None, - metadata_schema: gca_metadata_schema.MetadataSchema = None, - metadata_schema_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_metadata_schema.MetadataSchema: + def create_metadata_schema( + self, + request: metadata_service.CreateMetadataSchemaRequest = None, + *, + parent: str = None, + metadata_schema: gca_metadata_schema.MetadataSchema = None, + metadata_schema_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_metadata_schema.MetadataSchema: r"""Creates an MetadataSchema. Args: @@ -2484,8 +2464,10 @@ def create_metadata_schema(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.CreateMetadataSchemaRequest. @@ -2511,30 +2493,24 @@ def create_metadata_schema(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_metadata_schema(self, - request: metadata_service.GetMetadataSchemaRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> metadata_schema.MetadataSchema: + def get_metadata_schema( + self, + request: metadata_service.GetMetadataSchemaRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> metadata_schema.MetadataSchema: r"""Retrieves a specific MetadataSchema. Args: @@ -2565,8 +2541,10 @@ def get_metadata_schema(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.GetMetadataSchemaRequest. @@ -2588,30 +2566,24 @@ def get_metadata_schema(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_metadata_schemas(self, - request: metadata_service.ListMetadataSchemasRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListMetadataSchemasPager: + def list_metadata_schemas( + self, + request: metadata_service.ListMetadataSchemasRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListMetadataSchemasPager: r"""Lists MetadataSchemas. Args: @@ -2648,8 +2620,10 @@ def list_metadata_schemas(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.ListMetadataSchemasRequest. @@ -2671,39 +2645,30 @@ def list_metadata_schemas(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListMetadataSchemasPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def query_artifact_lineage_subgraph(self, - request: metadata_service.QueryArtifactLineageSubgraphRequest = None, - *, - artifact: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> lineage_subgraph.LineageSubgraph: + def query_artifact_lineage_subgraph( + self, + request: metadata_service.QueryArtifactLineageSubgraphRequest = None, + *, + artifact: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> lineage_subgraph.LineageSubgraph: r"""Retrieves lineage of an Artifact represented through Artifacts and Executions connected by Event edges and returned as a LineageSubgraph. @@ -2745,14 +2710,18 @@ def query_artifact_lineage_subgraph(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a metadata_service.QueryArtifactLineageSubgraphRequest. # There's no risk of modifying the input as we've already verified # there are no flattened fields. - if not isinstance(request, metadata_service.QueryArtifactLineageSubgraphRequest): + if not isinstance( + request, metadata_service.QueryArtifactLineageSubgraphRequest + ): request = metadata_service.QueryArtifactLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2763,43 +2732,31 @@ def query_artifact_lineage_subgraph(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.query_artifact_lineage_subgraph] + rpc = self._transport._wrapped_methods[ + self._transport.query_artifact_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('artifact', request.artifact), - )), + gapic_v1.routing_header.to_grpc_metadata((("artifact", request.artifact),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MetadataServiceClient', -) +__all__ = ("MetadataServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py index f4acfb6800..4131483c16 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -43,29 +43,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class MetadataServiceTransport(abc.ABC): """Abstract transport class for MetadataService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -88,8 +88,8 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # Save the scopes. @@ -98,17 +98,19 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=self._scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -122,9 +124,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_metadata_store: gapic_v1.method.wrap_method( - self.get_metadata_store, - default_timeout=None, - client_info=client_info, + self.get_metadata_store, default_timeout=None, client_info=client_info, ), self.list_metadata_stores: gapic_v1.method.wrap_method( self.list_metadata_stores, @@ -137,49 +137,31 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.create_artifact: gapic_v1.method.wrap_method( - self.create_artifact, - default_timeout=None, - client_info=client_info, + self.create_artifact, default_timeout=None, client_info=client_info, ), self.get_artifact: gapic_v1.method.wrap_method( - self.get_artifact, - default_timeout=None, - client_info=client_info, + self.get_artifact, default_timeout=None, client_info=client_info, ), self.list_artifacts: gapic_v1.method.wrap_method( - self.list_artifacts, - default_timeout=None, - client_info=client_info, + self.list_artifacts, default_timeout=None, client_info=client_info, ), self.update_artifact: gapic_v1.method.wrap_method( - self.update_artifact, - default_timeout=None, - client_info=client_info, + self.update_artifact, default_timeout=None, client_info=client_info, ), self.create_context: gapic_v1.method.wrap_method( - self.create_context, - default_timeout=None, - client_info=client_info, + self.create_context, default_timeout=None, client_info=client_info, ), self.get_context: gapic_v1.method.wrap_method( - self.get_context, - default_timeout=None, - client_info=client_info, + self.get_context, default_timeout=None, client_info=client_info, ), self.list_contexts: gapic_v1.method.wrap_method( - self.list_contexts, - default_timeout=None, - client_info=client_info, + self.list_contexts, default_timeout=None, client_info=client_info, ), self.update_context: gapic_v1.method.wrap_method( - self.update_context, - default_timeout=None, - client_info=client_info, + self.update_context, default_timeout=None, client_info=client_info, ), self.delete_context: gapic_v1.method.wrap_method( - self.delete_context, - default_timeout=None, - client_info=client_info, + self.delete_context, default_timeout=None, client_info=client_info, ), self.add_context_artifacts_and_executions: gapic_v1.method.wrap_method( self.add_context_artifacts_and_executions, @@ -197,24 +179,16 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.create_execution: gapic_v1.method.wrap_method( - self.create_execution, - default_timeout=None, - client_info=client_info, + self.create_execution, default_timeout=None, client_info=client_info, ), self.get_execution: gapic_v1.method.wrap_method( - self.get_execution, - default_timeout=None, - client_info=client_info, + self.get_execution, default_timeout=None, client_info=client_info, ), self.list_executions: gapic_v1.method.wrap_method( - self.list_executions, - default_timeout=None, - client_info=client_info, + self.list_executions, default_timeout=None, client_info=client_info, ), self.update_execution: gapic_v1.method.wrap_method( - self.update_execution, - default_timeout=None, - client_info=client_info, + self.update_execution, default_timeout=None, client_info=client_info, ), self.add_execution_events: gapic_v1.method.wrap_method( self.add_execution_events, @@ -232,9 +206,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_metadata_schema: gapic_v1.method.wrap_method( - self.get_metadata_schema, - default_timeout=None, - client_info=client_info, + self.get_metadata_schema, default_timeout=None, client_info=client_info, ), self.list_metadata_schemas: gapic_v1.method.wrap_method( self.list_metadata_schemas, @@ -246,7 +218,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -255,240 +226,283 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_metadata_store(self) -> typing.Callable[ - [metadata_service.CreateMetadataStoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_metadata_store( + self, + ) -> typing.Callable[ + [metadata_service.CreateMetadataStoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_metadata_store(self) -> typing.Callable[ - [metadata_service.GetMetadataStoreRequest], - typing.Union[ - metadata_store.MetadataStore, - typing.Awaitable[metadata_store.MetadataStore] - ]]: + def get_metadata_store( + self, + ) -> typing.Callable[ + [metadata_service.GetMetadataStoreRequest], + typing.Union[ + metadata_store.MetadataStore, typing.Awaitable[metadata_store.MetadataStore] + ], + ]: raise NotImplementedError() @property - def list_metadata_stores(self) -> typing.Callable[ - [metadata_service.ListMetadataStoresRequest], - typing.Union[ - metadata_service.ListMetadataStoresResponse, - typing.Awaitable[metadata_service.ListMetadataStoresResponse] - ]]: + def list_metadata_stores( + self, + ) -> typing.Callable[ + [metadata_service.ListMetadataStoresRequest], + typing.Union[ + metadata_service.ListMetadataStoresResponse, + typing.Awaitable[metadata_service.ListMetadataStoresResponse], + ], + ]: raise NotImplementedError() @property - def delete_metadata_store(self) -> typing.Callable[ - [metadata_service.DeleteMetadataStoreRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_metadata_store( + self, + ) -> typing.Callable[ + [metadata_service.DeleteMetadataStoreRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def create_artifact(self) -> typing.Callable[ - [metadata_service.CreateArtifactRequest], - typing.Union[ - gca_artifact.Artifact, - typing.Awaitable[gca_artifact.Artifact] - ]]: + def create_artifact( + self, + ) -> typing.Callable[ + [metadata_service.CreateArtifactRequest], + typing.Union[gca_artifact.Artifact, typing.Awaitable[gca_artifact.Artifact]], + ]: raise NotImplementedError() @property - def get_artifact(self) -> typing.Callable[ - [metadata_service.GetArtifactRequest], - typing.Union[ - artifact.Artifact, - typing.Awaitable[artifact.Artifact] - ]]: + def get_artifact( + self, + ) -> typing.Callable[ + [metadata_service.GetArtifactRequest], + typing.Union[artifact.Artifact, typing.Awaitable[artifact.Artifact]], + ]: raise NotImplementedError() @property - def list_artifacts(self) -> typing.Callable[ - [metadata_service.ListArtifactsRequest], - typing.Union[ - metadata_service.ListArtifactsResponse, - typing.Awaitable[metadata_service.ListArtifactsResponse] - ]]: + def list_artifacts( + self, + ) -> typing.Callable[ + [metadata_service.ListArtifactsRequest], + typing.Union[ + metadata_service.ListArtifactsResponse, + typing.Awaitable[metadata_service.ListArtifactsResponse], + ], + ]: raise NotImplementedError() @property - def update_artifact(self) -> typing.Callable[ - [metadata_service.UpdateArtifactRequest], - typing.Union[ - gca_artifact.Artifact, - typing.Awaitable[gca_artifact.Artifact] - ]]: + def update_artifact( + self, + ) -> typing.Callable[ + [metadata_service.UpdateArtifactRequest], + typing.Union[gca_artifact.Artifact, typing.Awaitable[gca_artifact.Artifact]], + ]: raise NotImplementedError() @property - def create_context(self) -> typing.Callable[ - [metadata_service.CreateContextRequest], - typing.Union[ - gca_context.Context, - typing.Awaitable[gca_context.Context] - ]]: + def create_context( + self, + ) -> typing.Callable[ + [metadata_service.CreateContextRequest], + typing.Union[gca_context.Context, typing.Awaitable[gca_context.Context]], + ]: raise NotImplementedError() @property - def get_context(self) -> typing.Callable[ - [metadata_service.GetContextRequest], - typing.Union[ - context.Context, - typing.Awaitable[context.Context] - ]]: + def get_context( + self, + ) -> typing.Callable[ + [metadata_service.GetContextRequest], + typing.Union[context.Context, typing.Awaitable[context.Context]], + ]: raise NotImplementedError() @property - def list_contexts(self) -> typing.Callable[ - [metadata_service.ListContextsRequest], - typing.Union[ - metadata_service.ListContextsResponse, - typing.Awaitable[metadata_service.ListContextsResponse] - ]]: + def list_contexts( + self, + ) -> typing.Callable[ + [metadata_service.ListContextsRequest], + typing.Union[ + metadata_service.ListContextsResponse, + typing.Awaitable[metadata_service.ListContextsResponse], + ], + ]: raise NotImplementedError() @property - def update_context(self) -> typing.Callable[ - [metadata_service.UpdateContextRequest], - typing.Union[ - gca_context.Context, - typing.Awaitable[gca_context.Context] - ]]: + def update_context( + self, + ) -> typing.Callable[ + [metadata_service.UpdateContextRequest], + typing.Union[gca_context.Context, typing.Awaitable[gca_context.Context]], + ]: raise NotImplementedError() @property - def delete_context(self) -> typing.Callable[ - [metadata_service.DeleteContextRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_context( + self, + ) -> typing.Callable[ + [metadata_service.DeleteContextRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def add_context_artifacts_and_executions(self) -> typing.Callable[ - [metadata_service.AddContextArtifactsAndExecutionsRequest], - typing.Union[ - metadata_service.AddContextArtifactsAndExecutionsResponse, - typing.Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse] - ]]: + def add_context_artifacts_and_executions( + self, + ) -> typing.Callable[ + [metadata_service.AddContextArtifactsAndExecutionsRequest], + typing.Union[ + metadata_service.AddContextArtifactsAndExecutionsResponse, + typing.Awaitable[metadata_service.AddContextArtifactsAndExecutionsResponse], + ], + ]: raise NotImplementedError() @property - def add_context_children(self) -> typing.Callable[ - [metadata_service.AddContextChildrenRequest], - typing.Union[ - metadata_service.AddContextChildrenResponse, - typing.Awaitable[metadata_service.AddContextChildrenResponse] - ]]: + def add_context_children( + self, + ) -> typing.Callable[ + [metadata_service.AddContextChildrenRequest], + typing.Union[ + metadata_service.AddContextChildrenResponse, + typing.Awaitable[metadata_service.AddContextChildrenResponse], + ], + ]: raise NotImplementedError() @property - def query_context_lineage_subgraph(self) -> typing.Callable[ - [metadata_service.QueryContextLineageSubgraphRequest], - typing.Union[ - lineage_subgraph.LineageSubgraph, - typing.Awaitable[lineage_subgraph.LineageSubgraph] - ]]: + def query_context_lineage_subgraph( + self, + ) -> typing.Callable[ + [metadata_service.QueryContextLineageSubgraphRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph], + ], + ]: raise NotImplementedError() @property - def create_execution(self) -> typing.Callable[ - [metadata_service.CreateExecutionRequest], - typing.Union[ - gca_execution.Execution, - typing.Awaitable[gca_execution.Execution] - ]]: + def create_execution( + self, + ) -> typing.Callable[ + [metadata_service.CreateExecutionRequest], + typing.Union[ + gca_execution.Execution, typing.Awaitable[gca_execution.Execution] + ], + ]: raise NotImplementedError() @property - def get_execution(self) -> typing.Callable[ - [metadata_service.GetExecutionRequest], - typing.Union[ - execution.Execution, - typing.Awaitable[execution.Execution] - ]]: + def get_execution( + self, + ) -> typing.Callable[ + [metadata_service.GetExecutionRequest], + typing.Union[execution.Execution, typing.Awaitable[execution.Execution]], + ]: raise NotImplementedError() @property - def list_executions(self) -> typing.Callable[ - [metadata_service.ListExecutionsRequest], - typing.Union[ - metadata_service.ListExecutionsResponse, - typing.Awaitable[metadata_service.ListExecutionsResponse] - ]]: + def list_executions( + self, + ) -> typing.Callable[ + [metadata_service.ListExecutionsRequest], + typing.Union[ + metadata_service.ListExecutionsResponse, + typing.Awaitable[metadata_service.ListExecutionsResponse], + ], + ]: raise NotImplementedError() @property - def update_execution(self) -> typing.Callable[ - [metadata_service.UpdateExecutionRequest], - typing.Union[ - gca_execution.Execution, - typing.Awaitable[gca_execution.Execution] - ]]: + def update_execution( + self, + ) -> typing.Callable[ + [metadata_service.UpdateExecutionRequest], + typing.Union[ + gca_execution.Execution, typing.Awaitable[gca_execution.Execution] + ], + ]: raise NotImplementedError() @property - def add_execution_events(self) -> typing.Callable[ - [metadata_service.AddExecutionEventsRequest], - typing.Union[ - metadata_service.AddExecutionEventsResponse, - typing.Awaitable[metadata_service.AddExecutionEventsResponse] - ]]: + def add_execution_events( + self, + ) -> typing.Callable[ + [metadata_service.AddExecutionEventsRequest], + typing.Union[ + metadata_service.AddExecutionEventsResponse, + typing.Awaitable[metadata_service.AddExecutionEventsResponse], + ], + ]: raise NotImplementedError() @property - def query_execution_inputs_and_outputs(self) -> typing.Callable[ - [metadata_service.QueryExecutionInputsAndOutputsRequest], - typing.Union[ - lineage_subgraph.LineageSubgraph, - typing.Awaitable[lineage_subgraph.LineageSubgraph] - ]]: + def query_execution_inputs_and_outputs( + self, + ) -> typing.Callable[ + [metadata_service.QueryExecutionInputsAndOutputsRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph], + ], + ]: raise NotImplementedError() @property - def create_metadata_schema(self) -> typing.Callable[ - [metadata_service.CreateMetadataSchemaRequest], - typing.Union[ - gca_metadata_schema.MetadataSchema, - typing.Awaitable[gca_metadata_schema.MetadataSchema] - ]]: + def create_metadata_schema( + self, + ) -> typing.Callable[ + [metadata_service.CreateMetadataSchemaRequest], + typing.Union[ + gca_metadata_schema.MetadataSchema, + typing.Awaitable[gca_metadata_schema.MetadataSchema], + ], + ]: raise NotImplementedError() @property - def get_metadata_schema(self) -> typing.Callable[ - [metadata_service.GetMetadataSchemaRequest], - typing.Union[ - metadata_schema.MetadataSchema, - typing.Awaitable[metadata_schema.MetadataSchema] - ]]: + def get_metadata_schema( + self, + ) -> typing.Callable[ + [metadata_service.GetMetadataSchemaRequest], + typing.Union[ + metadata_schema.MetadataSchema, + typing.Awaitable[metadata_schema.MetadataSchema], + ], + ]: raise NotImplementedError() @property - def list_metadata_schemas(self) -> typing.Callable[ - [metadata_service.ListMetadataSchemasRequest], - typing.Union[ - metadata_service.ListMetadataSchemasResponse, - typing.Awaitable[metadata_service.ListMetadataSchemasResponse] - ]]: + def list_metadata_schemas( + self, + ) -> typing.Callable[ + [metadata_service.ListMetadataSchemasRequest], + typing.Union[ + metadata_service.ListMetadataSchemasResponse, + typing.Awaitable[metadata_service.ListMetadataSchemasResponse], + ], + ]: raise NotImplementedError() @property - def query_artifact_lineage_subgraph(self) -> typing.Callable[ - [metadata_service.QueryArtifactLineageSubgraphRequest], - typing.Union[ - lineage_subgraph.LineageSubgraph, - typing.Awaitable[lineage_subgraph.LineageSubgraph] - ]]: + def query_artifact_lineage_subgraph( + self, + ) -> typing.Callable[ + [metadata_service.QueryArtifactLineageSubgraphRequest], + typing.Union[ + lineage_subgraph.LineageSubgraph, + typing.Awaitable[lineage_subgraph.LineageSubgraph], + ], + ]: raise NotImplementedError() -__all__ = ( - 'MetadataServiceTransport', -) +__all__ = ("MetadataServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py index fc9a790674..2ae1992f1b 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -956,9 +956,12 @@ def list_metadata_schemas( return self._stubs["list_metadata_schemas"] @property - def query_artifact_lineage_subgraph(self) -> Callable[ - [metadata_service.QueryArtifactLineageSubgraphRequest], - lineage_subgraph.LineageSubgraph]: + def query_artifact_lineage_subgraph( + self, + ) -> Callable[ + [metadata_service.QueryArtifactLineageSubgraphRequest], + lineage_subgraph.LineageSubgraph, + ]: r"""Return a callable for the query artifact lineage subgraph method over gRPC. @@ -976,13 +979,15 @@ def query_artifact_lineage_subgraph(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'query_artifact_lineage_subgraph' not in self._stubs: - self._stubs['query_artifact_lineage_subgraph'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/QueryArtifactLineageSubgraph', + if "query_artifact_lineage_subgraph" not in self._stubs: + self._stubs[ + "query_artifact_lineage_subgraph" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/QueryArtifactLineageSubgraph", request_serializer=metadata_service.QueryArtifactLineageSubgraphRequest.serialize, response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, ) - return self._stubs['query_artifact_lineage_subgraph'] + return self._stubs["query_artifact_lineage_subgraph"] __all__ = ("MetadataServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py index d7e9feceec..2cd00db999 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -986,9 +986,12 @@ def list_metadata_schemas( return self._stubs["list_metadata_schemas"] @property - def query_artifact_lineage_subgraph(self) -> Callable[ - [metadata_service.QueryArtifactLineageSubgraphRequest], - Awaitable[lineage_subgraph.LineageSubgraph]]: + def query_artifact_lineage_subgraph( + self, + ) -> Callable[ + [metadata_service.QueryArtifactLineageSubgraphRequest], + Awaitable[lineage_subgraph.LineageSubgraph], + ]: r"""Return a callable for the query artifact lineage subgraph method over gRPC. @@ -1006,13 +1009,15 @@ def query_artifact_lineage_subgraph(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'query_artifact_lineage_subgraph' not in self._stubs: - self._stubs['query_artifact_lineage_subgraph'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MetadataService/QueryArtifactLineageSubgraph', + if "query_artifact_lineage_subgraph" not in self._stubs: + self._stubs[ + "query_artifact_lineage_subgraph" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/QueryArtifactLineageSubgraph", request_serializer=metadata_service.QueryArtifactLineageSubgraphRequest.serialize, response_deserializer=lineage_subgraph.LineageSubgraph.deserialize, ) - return self._stubs['query_artifact_lineage_subgraph'] + return self._stubs["query_artifact_lineage_subgraph"] __all__ = ("MetadataServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 3d57aa5c1f..d4324e3089 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,13 +50,14 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry['grpc'] = MigrationServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MigrationServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry["grpc"] = MigrationServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -110,7 +111,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -145,9 +146,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: MigrationServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -162,143 +162,183 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: + def annotated_dataset_path( + project: str, dataset: str, annotated_dataset: str, + ) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str,str]: + def parse_annotated_dataset_path(path: str) -> Dict[str, str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def version_path(project: str,model: str,version: str,) -> str: + def version_path(project: str, model: str, version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + return "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) @staticmethod - def parse_version_path(path: str) -> Dict[str,str]: + def parse_version_path(path: str) -> Dict[str, str]: """Parse a version path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -342,7 +382,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -352,7 +394,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -364,7 +408,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -376,8 +422,10 @@ def __init__(self, *, if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -396,14 +444,15 @@ def __init__(self, *, client_info=client_info, ) - def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -444,8 +493,10 @@ def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -462,45 +513,40 @@ def search_migratable_resources(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -549,8 +595,10 @@ def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -574,18 +622,11 @@ def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation.from_gapic( @@ -599,21 +640,14 @@ def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceClient', -) +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 224f714816..1724541446 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -61,13 +61,12 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry['grpc'] = ModelServiceGrpcTransport - _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[ModelServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -118,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -153,9 +152,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -170,121 +168,162 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_path(path: str) -> Dict[str, str]: """Parse a model_evaluation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -328,7 +367,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -338,7 +379,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -350,7 +393,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -362,8 +407,10 @@ def __init__(self, *, if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -382,15 +429,16 @@ def __init__(self, *, client_info=client_info, ) - def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -433,8 +481,10 @@ def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -458,18 +508,11 @@ def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -482,14 +525,15 @@ def upload_model(self, # Done; return the response. return response - def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -519,8 +563,10 @@ def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -542,30 +588,24 @@ def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -601,8 +641,10 @@ def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -624,40 +666,31 @@ def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -695,8 +728,10 @@ def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -720,30 +755,26 @@ def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -791,8 +822,10 @@ def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -814,18 +847,11 @@ def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -838,15 +864,16 @@ def delete_model(self, # Done; return the response. return response - def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -894,8 +921,10 @@ def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -919,18 +948,11 @@ def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -943,14 +965,15 @@ def export_model(self, # Done; return the response. return response - def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -985,8 +1008,10 @@ def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -1008,30 +1033,24 @@ def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1067,8 +1086,10 @@ def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1090,39 +1111,30 @@ def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1157,8 +1169,10 @@ def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1175,35 +1189,31 @@ def get_model_evaluation_slice(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1239,8 +1249,10 @@ def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1257,52 +1269,37 @@ def list_model_evaluation_slices(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceClient', -) +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index aa99b2c0c3..9f61aff314 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -41,7 +41,9 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -59,13 +61,14 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry['grpc'] = PipelineServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PipelineServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry["grpc"] = PipelineServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +119,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -151,9 +154,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PipelineServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -168,99 +170,122 @@ def transport(self) -> PipelineServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -304,7 +329,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -314,7 +341,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -326,7 +355,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -338,8 +369,10 @@ def __init__(self, *, if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -358,15 +391,16 @@ def __init__(self, *, client_info=client_info, ) - def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -411,8 +445,10 @@ def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -436,30 +472,24 @@ def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -496,8 +526,10 @@ def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -519,30 +551,24 @@ def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -578,8 +604,10 @@ def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -601,39 +629,30 @@ def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -679,8 +698,10 @@ def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -702,18 +723,11 @@ def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -726,14 +740,15 @@ def delete_training_pipeline(self, # Done; return the response. return response - def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -772,8 +787,10 @@ def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -795,35 +812,23 @@ def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceClient', -) +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index caa4f9aa26..e13c4ce456 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,13 +54,16 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport - _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +120,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -152,9 +155,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: SpecialistPoolServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -169,77 +171,88 @@ def transport(self) -> SpecialistPoolServiceTransport: return self._transport @staticmethod - def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str: + def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" - return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) @staticmethod - def parse_specialist_pool_path(path: str) -> Dict[str,str]: + def parse_specialist_pool_path(path: str) -> Dict[str, str]: """Parse a specialist_pool path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -283,7 +296,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) client_cert_source_func = None is_mtls = False @@ -293,7 +308,9 @@ def __init__(self, *, client_cert_source_func = client_options.client_cert_source else: is_mtls = mtls.has_default_client_cert_source() - client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -305,7 +322,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -317,8 +336,10 @@ def __init__(self, *, if isinstance(transport, SpecialistPoolServiceTransport): # transport is a SpecialistPoolServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -337,15 +358,16 @@ def __init__(self, *, client_info=client_info, ) - def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -393,8 +415,10 @@ def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -418,18 +442,11 @@ def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -442,14 +459,15 @@ def create_specialist_pool(self, # Done; return the response. return response - def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -491,8 +509,10 @@ def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -514,30 +534,24 @@ def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -573,8 +587,10 @@ def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -596,39 +612,30 @@ def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -675,8 +682,10 @@ def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -698,18 +707,11 @@ def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -722,15 +724,16 @@ def delete_specialist_pool(self, # Done; return the response. return response - def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gac_operation.Operation: + def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -777,8 +780,10 @@ def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -802,18 +807,13 @@ def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = gac_operation.from_gapic( @@ -827,21 +827,14 @@ def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceClient', -) +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index f8d4a0e95d..fb550654e8 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,24 +15,12 @@ # limitations under the License. # -from .annotation import ( - Annotation, -) -from .annotation_spec import ( - AnnotationSpec, -) -from .artifact import ( - Artifact, -) -from .batch_prediction_job import ( - BatchPredictionJob, -) -from .completion_stats import ( - CompletionStats, -) -from .context import ( - Context, -) +from .annotation import Annotation +from .annotation_spec import AnnotationSpec +from .artifact import Artifact +from .batch_prediction_job import BatchPredictionJob +from .completion_stats import CompletionStats +from .context import Context from .custom_job import ( ContainerSpec, CustomJob, @@ -41,9 +29,7 @@ Scheduling, WorkerPoolSpec, ) -from .data_item import ( - DataItem, -) +from .data_item import DataItem from .data_labeling_job import ( ActiveLearningConfig, DataLabelingJob, @@ -75,15 +61,9 @@ ListDatasetsResponse, UpdateDatasetRequest, ) -from .deployed_index_ref import ( - DeployedIndexRef, -) -from .deployed_model_ref import ( - DeployedModelRef, -) -from .encryption_spec import ( - EncryptionSpec, -) +from .deployed_index_ref import DeployedIndexRef +from .deployed_model_ref import DeployedModelRef +from .encryption_spec import EncryptionSpec from .endpoint import ( DeployedModel, Endpoint, @@ -103,18 +83,10 @@ UndeployModelResponse, UpdateEndpointRequest, ) -from .entity_type import ( - EntityType, -) -from .env_var import ( - EnvVar, -) -from .event import ( - Event, -) -from .execution import ( - Execution, -) +from .entity_type import EntityType +from .env_var import EnvVar +from .event import Event +from .execution import Execution from .explanation import ( Attribution, Explanation, @@ -129,25 +101,15 @@ SmoothGradConfig, XraiAttribution, ) -from .explanation_metadata import ( - ExplanationMetadata, -) -from .feature import ( - Feature, -) -from .feature_monitoring_stats import ( - FeatureStatsAnomaly, -) +from .explanation_metadata import ExplanationMetadata +from .feature import Feature +from .feature_monitoring_stats import FeatureStatsAnomaly from .feature_selector import ( FeatureSelector, IdMatcher, ) -from .featurestore import ( - Featurestore, -) -from .featurestore_monitoring import ( - FeaturestoreMonitoringConfig, -) +from .featurestore import Featurestore +from .featurestore_monitoring import FeaturestoreMonitoringConfig from .featurestore_online_service import ( FeatureValue, FeatureValueList, @@ -193,12 +155,8 @@ UpdateFeaturestoreOperationMetadata, UpdateFeaturestoreRequest, ) -from .hyperparameter_tuning_job import ( - HyperparameterTuningJob, -) -from .index import ( - Index, -) +from .hyperparameter_tuning_job import HyperparameterTuningJob +from .index import Index from .index_endpoint import ( DeployedIndex, DeployedIndexAuthConfig, @@ -279,9 +237,7 @@ UpdateModelDeploymentMonitoringJobOperationMetadata, UpdateModelDeploymentMonitoringJobRequest, ) -from .lineage_subgraph import ( - LineageSubgraph, -) +from .lineage_subgraph import LineageSubgraph from .machine_resources import ( AutomaticResources, AutoscalingMetricSpec, @@ -291,12 +247,8 @@ MachineSpec, ResourcesConsumed, ) -from .manual_batch_tuning_parameters import ( - ManualBatchTuningParameters, -) -from .metadata_schema import ( - MetadataSchema, -) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters +from .metadata_schema import MetadataSchema from .metadata_service import ( AddContextArtifactsAndExecutionsRequest, AddContextArtifactsAndExecutionsResponse, @@ -335,12 +287,8 @@ UpdateContextRequest, UpdateExecutionRequest, ) -from .metadata_store import ( - MetadataStore, -) -from .migratable_resource import ( - MigratableResource, -) +from .metadata_store import MetadataStore +from .migratable_resource import MigratableResource from .migration_service import ( BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, @@ -364,12 +312,8 @@ ModelMonitoringStatsAnomalies, ModelDeploymentMonitoringObjectiveType, ) -from .model_evaluation import ( - ModelEvaluation, -) -from .model_evaluation_slice import ( - ModelEvaluationSlice, -) +from .model_evaluation import ModelEvaluation +from .model_evaluation_slice import ModelEvaluationSlice from .model_monitoring import ( ModelMonitoringAlertConfig, ModelMonitoringObjectiveConfig, @@ -413,9 +357,7 @@ PredictRequest, PredictResponse, ) -from .specialist_pool import ( - SpecialistPool, -) +from .specialist_pool import SpecialistPool from .specialist_pool_service import ( CreateSpecialistPoolOperationMetadata, CreateSpecialistPoolRequest, @@ -446,9 +388,7 @@ Int64Array, StringArray, ) -from .user_action_reference import ( - UserActionReference, -) +from .user_action_reference import UserActionReference from .vizier_service import ( AddTrialMeasurementRequest, CheckTrialEarlyStoppingStateMetatdata, @@ -475,345 +415,345 @@ ) __all__ = ( - 'AcceleratorType', - 'Annotation', - 'AnnotationSpec', - 'Artifact', - 'BatchPredictionJob', - 'CompletionStats', - 'Context', - 'ContainerSpec', - 'CustomJob', - 'CustomJobSpec', - 'PythonPackageSpec', - 'Scheduling', - 'WorkerPoolSpec', - 'DataItem', - 'ActiveLearningConfig', - 'DataLabelingJob', - 'SampleConfig', - 'TrainingConfig', - 'Dataset', - 'ExportDataConfig', - 'ImportDataConfig', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'DeleteDatasetRequest', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'GetAnnotationSpecRequest', - 'GetDatasetRequest', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'UpdateDatasetRequest', - 'DeployedIndexRef', - 'DeployedModelRef', - 'EncryptionSpec', - 'DeployedModel', - 'Endpoint', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateEndpointRequest', - 'EntityType', - 'EnvVar', - 'Event', - 'Execution', - 'Attribution', - 'Explanation', - 'ExplanationMetadataOverride', - 'ExplanationParameters', - 'ExplanationSpec', - 'ExplanationSpecOverride', - 'FeatureNoiseSigma', - 'IntegratedGradientsAttribution', - 'ModelExplanation', - 'SampledShapleyAttribution', - 'SmoothGradConfig', - 'XraiAttribution', - 'ExplanationMetadata', - 'Feature', - 'FeatureStatsAnomaly', - 'FeatureSelector', - 'IdMatcher', - 'Featurestore', - 'FeaturestoreMonitoringConfig', - 'FeatureValue', - 'FeatureValueList', - 'ReadFeatureValuesRequest', - 'ReadFeatureValuesResponse', - 'ReadSetting', - 'StreamingReadFeatureValuesRequest', - 'BatchCreateFeaturesOperationMetadata', - 'BatchCreateFeaturesRequest', - 'BatchCreateFeaturesResponse', - 'BatchReadFeatureValuesOperationMetadata', - 'BatchReadFeatureValuesRequest', - 'BatchReadFeatureValuesResponse', - 'CreateEntityTypeOperationMetadata', - 'CreateEntityTypeRequest', - 'CreateFeatureOperationMetadata', - 'CreateFeatureRequest', - 'CreateFeaturestoreOperationMetadata', - 'CreateFeaturestoreRequest', - 'DeleteEntityTypeRequest', - 'DeleteFeatureRequest', - 'DeleteFeaturestoreRequest', - 'DestinationFeatureSetting', - 'FeatureValueDestination', - 'GetEntityTypeRequest', - 'GetFeatureRequest', - 'GetFeaturestoreRequest', - 'ImportFeatureValuesOperationMetadata', - 'ImportFeatureValuesRequest', - 'ImportFeatureValuesResponse', - 'ListEntityTypesRequest', - 'ListEntityTypesResponse', - 'ListFeaturesRequest', - 'ListFeaturesResponse', - 'ListFeaturestoresRequest', - 'ListFeaturestoresResponse', - 'SearchFeaturesRequest', - 'SearchFeaturesResponse', - 'UpdateEntityTypeRequest', - 'UpdateFeatureRequest', - 'UpdateFeaturestoreOperationMetadata', - 'UpdateFeaturestoreRequest', - 'HyperparameterTuningJob', - 'Index', - 'DeployedIndex', - 'DeployedIndexAuthConfig', - 'IndexEndpoint', - 'IndexPrivateEndpoints', - 'CreateIndexEndpointOperationMetadata', - 'CreateIndexEndpointRequest', - 'DeleteIndexEndpointRequest', - 'DeployIndexOperationMetadata', - 'DeployIndexRequest', - 'DeployIndexResponse', - 'GetIndexEndpointRequest', - 'ListIndexEndpointsRequest', - 'ListIndexEndpointsResponse', - 'UndeployIndexOperationMetadata', - 'UndeployIndexRequest', - 'UndeployIndexResponse', - 'UpdateIndexEndpointRequest', - 'CreateIndexOperationMetadata', - 'CreateIndexRequest', - 'DeleteIndexRequest', - 'GetIndexRequest', - 'ListIndexesRequest', - 'ListIndexesResponse', - 'NearestNeighborSearchOperationMetadata', - 'UpdateIndexOperationMetadata', - 'UpdateIndexRequest', - 'AvroSource', - 'BigQueryDestination', - 'BigQuerySource', - 'ContainerRegistryDestination', - 'CsvDestination', - 'CsvSource', - 'GcsDestination', - 'GcsSource', - 'TFRecordDestination', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'CreateModelDeploymentMonitoringJobRequest', - 'DeleteBatchPredictionJobRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteHyperparameterTuningJobRequest', - 'DeleteModelDeploymentMonitoringJobRequest', - 'GetBatchPredictionJobRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetHyperparameterTuningJobRequest', - 'GetModelDeploymentMonitoringJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'ListModelDeploymentMonitoringJobsRequest', - 'ListModelDeploymentMonitoringJobsResponse', - 'PauseModelDeploymentMonitoringJobRequest', - 'ResumeModelDeploymentMonitoringJobRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesRequest', - 'SearchModelDeploymentMonitoringStatsAnomaliesResponse', - 'UpdateModelDeploymentMonitoringJobOperationMetadata', - 'UpdateModelDeploymentMonitoringJobRequest', - 'JobState', - 'LineageSubgraph', - 'AutomaticResources', - 'AutoscalingMetricSpec', - 'BatchDedicatedResources', - 'DedicatedResources', - 'DiskSpec', - 'MachineSpec', - 'ResourcesConsumed', - 'ManualBatchTuningParameters', - 'MetadataSchema', - 'AddContextArtifactsAndExecutionsRequest', - 'AddContextArtifactsAndExecutionsResponse', - 'AddContextChildrenRequest', - 'AddContextChildrenResponse', - 'AddExecutionEventsRequest', - 'AddExecutionEventsResponse', - 'CreateArtifactRequest', - 'CreateContextRequest', - 'CreateExecutionRequest', - 'CreateMetadataSchemaRequest', - 'CreateMetadataStoreOperationMetadata', - 'CreateMetadataStoreRequest', - 'DeleteContextRequest', - 'DeleteMetadataStoreOperationMetadata', - 'DeleteMetadataStoreRequest', - 'GetArtifactRequest', - 'GetContextRequest', - 'GetExecutionRequest', - 'GetMetadataSchemaRequest', - 'GetMetadataStoreRequest', - 'ListArtifactsRequest', - 'ListArtifactsResponse', - 'ListContextsRequest', - 'ListContextsResponse', - 'ListExecutionsRequest', - 'ListExecutionsResponse', - 'ListMetadataSchemasRequest', - 'ListMetadataSchemasResponse', - 'ListMetadataStoresRequest', - 'ListMetadataStoresResponse', - 'QueryArtifactLineageSubgraphRequest', - 'QueryContextLineageSubgraphRequest', - 'QueryExecutionInputsAndOutputsRequest', - 'UpdateArtifactRequest', - 'UpdateContextRequest', - 'UpdateExecutionRequest', - 'MetadataStore', - 'MigratableResource', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'Model', - 'ModelContainerSpec', - 'Port', - 'PredictSchemata', - 'ModelDeploymentMonitoringBigQueryTable', - 'ModelDeploymentMonitoringJob', - 'ModelDeploymentMonitoringObjectiveConfig', - 'ModelDeploymentMonitoringScheduleConfig', - 'ModelMonitoringStatsAnomalies', - 'ModelDeploymentMonitoringObjectiveType', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'ModelMonitoringAlertConfig', - 'ModelMonitoringObjectiveConfig', - 'SamplingStrategy', - 'ThresholdConfig', - 'DeleteModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 'UploadModelResponse', - 'DeleteOperationMetadata', - 'GenericOperationMetadata', - 'CancelTrainingPipelineRequest', - 'CreateTrainingPipelineRequest', - 'DeleteTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'PipelineState', - 'ExplainRequest', - 'ExplainResponse', - 'PredictRequest', - 'PredictResponse', - 'SpecialistPool', - 'CreateSpecialistPoolOperationMetadata', - 'CreateSpecialistPoolRequest', - 'DeleteSpecialistPoolRequest', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'Measurement', - 'Study', - 'StudySpec', - 'Trial', - 'FilterSplit', - 'FractionSplit', - 'InputDataConfig', - 'PredefinedSplit', - 'TimestampSplit', - 'TrainingPipeline', - 'BoolArray', - 'DoubleArray', - 'Int64Array', - 'StringArray', - 'UserActionReference', - 'AddTrialMeasurementRequest', - 'CheckTrialEarlyStoppingStateMetatdata', - 'CheckTrialEarlyStoppingStateRequest', - 'CheckTrialEarlyStoppingStateResponse', - 'CompleteTrialRequest', - 'CreateStudyRequest', - 'CreateTrialRequest', - 'DeleteStudyRequest', - 'DeleteTrialRequest', - 'GetStudyRequest', - 'GetTrialRequest', - 'ListOptimalTrialsRequest', - 'ListOptimalTrialsResponse', - 'ListStudiesRequest', - 'ListStudiesResponse', - 'ListTrialsRequest', - 'ListTrialsResponse', - 'LookupStudyRequest', - 'StopTrialRequest', - 'SuggestTrialsMetadata', - 'SuggestTrialsRequest', - 'SuggestTrialsResponse', + "AcceleratorType", + "Annotation", + "AnnotationSpec", + "Artifact", + "BatchPredictionJob", + "CompletionStats", + "Context", + "ContainerSpec", + "CustomJob", + "CustomJobSpec", + "PythonPackageSpec", + "Scheduling", + "WorkerPoolSpec", + "DataItem", + "ActiveLearningConfig", + "DataLabelingJob", + "SampleConfig", + "TrainingConfig", + "Dataset", + "ExportDataConfig", + "ImportDataConfig", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "DeleteDatasetRequest", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "GetAnnotationSpecRequest", + "GetDatasetRequest", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "UpdateDatasetRequest", + "DeployedIndexRef", + "DeployedModelRef", + "EncryptionSpec", + "DeployedModel", + "Endpoint", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateEndpointRequest", + "EntityType", + "EnvVar", + "Event", + "Execution", + "Attribution", + "Explanation", + "ExplanationMetadataOverride", + "ExplanationParameters", + "ExplanationSpec", + "ExplanationSpecOverride", + "FeatureNoiseSigma", + "IntegratedGradientsAttribution", + "ModelExplanation", + "SampledShapleyAttribution", + "SmoothGradConfig", + "XraiAttribution", + "ExplanationMetadata", + "Feature", + "FeatureStatsAnomaly", + "FeatureSelector", + "IdMatcher", + "Featurestore", + "FeaturestoreMonitoringConfig", + "FeatureValue", + "FeatureValueList", + "ReadFeatureValuesRequest", + "ReadFeatureValuesResponse", + "ReadSetting", + "StreamingReadFeatureValuesRequest", + "BatchCreateFeaturesOperationMetadata", + "BatchCreateFeaturesRequest", + "BatchCreateFeaturesResponse", + "BatchReadFeatureValuesOperationMetadata", + "BatchReadFeatureValuesRequest", + "BatchReadFeatureValuesResponse", + "CreateEntityTypeOperationMetadata", + "CreateEntityTypeRequest", + "CreateFeatureOperationMetadata", + "CreateFeatureRequest", + "CreateFeaturestoreOperationMetadata", + "CreateFeaturestoreRequest", + "DeleteEntityTypeRequest", + "DeleteFeatureRequest", + "DeleteFeaturestoreRequest", + "DestinationFeatureSetting", + "FeatureValueDestination", + "GetEntityTypeRequest", + "GetFeatureRequest", + "GetFeaturestoreRequest", + "ImportFeatureValuesOperationMetadata", + "ImportFeatureValuesRequest", + "ImportFeatureValuesResponse", + "ListEntityTypesRequest", + "ListEntityTypesResponse", + "ListFeaturesRequest", + "ListFeaturesResponse", + "ListFeaturestoresRequest", + "ListFeaturestoresResponse", + "SearchFeaturesRequest", + "SearchFeaturesResponse", + "UpdateEntityTypeRequest", + "UpdateFeatureRequest", + "UpdateFeaturestoreOperationMetadata", + "UpdateFeaturestoreRequest", + "HyperparameterTuningJob", + "Index", + "DeployedIndex", + "DeployedIndexAuthConfig", + "IndexEndpoint", + "IndexPrivateEndpoints", + "CreateIndexEndpointOperationMetadata", + "CreateIndexEndpointRequest", + "DeleteIndexEndpointRequest", + "DeployIndexOperationMetadata", + "DeployIndexRequest", + "DeployIndexResponse", + "GetIndexEndpointRequest", + "ListIndexEndpointsRequest", + "ListIndexEndpointsResponse", + "UndeployIndexOperationMetadata", + "UndeployIndexRequest", + "UndeployIndexResponse", + "UpdateIndexEndpointRequest", + "CreateIndexOperationMetadata", + "CreateIndexRequest", + "DeleteIndexRequest", + "GetIndexRequest", + "ListIndexesRequest", + "ListIndexesResponse", + "NearestNeighborSearchOperationMetadata", + "UpdateIndexOperationMetadata", + "UpdateIndexRequest", + "AvroSource", + "BigQueryDestination", + "BigQuerySource", + "ContainerRegistryDestination", + "CsvDestination", + "CsvSource", + "GcsDestination", + "GcsSource", + "TFRecordDestination", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "CreateModelDeploymentMonitoringJobRequest", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteModelDeploymentMonitoringJobRequest", + "GetBatchPredictionJobRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetHyperparameterTuningJobRequest", + "GetModelDeploymentMonitoringJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListModelDeploymentMonitoringJobsRequest", + "ListModelDeploymentMonitoringJobsResponse", + "PauseModelDeploymentMonitoringJobRequest", + "ResumeModelDeploymentMonitoringJobRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesRequest", + "SearchModelDeploymentMonitoringStatsAnomaliesResponse", + "UpdateModelDeploymentMonitoringJobOperationMetadata", + "UpdateModelDeploymentMonitoringJobRequest", + "JobState", + "LineageSubgraph", + "AutomaticResources", + "AutoscalingMetricSpec", + "BatchDedicatedResources", + "DedicatedResources", + "DiskSpec", + "MachineSpec", + "ResourcesConsumed", + "ManualBatchTuningParameters", + "MetadataSchema", + "AddContextArtifactsAndExecutionsRequest", + "AddContextArtifactsAndExecutionsResponse", + "AddContextChildrenRequest", + "AddContextChildrenResponse", + "AddExecutionEventsRequest", + "AddExecutionEventsResponse", + "CreateArtifactRequest", + "CreateContextRequest", + "CreateExecutionRequest", + "CreateMetadataSchemaRequest", + "CreateMetadataStoreOperationMetadata", + "CreateMetadataStoreRequest", + "DeleteContextRequest", + "DeleteMetadataStoreOperationMetadata", + "DeleteMetadataStoreRequest", + "GetArtifactRequest", + "GetContextRequest", + "GetExecutionRequest", + "GetMetadataSchemaRequest", + "GetMetadataStoreRequest", + "ListArtifactsRequest", + "ListArtifactsResponse", + "ListContextsRequest", + "ListContextsResponse", + "ListExecutionsRequest", + "ListExecutionsResponse", + "ListMetadataSchemasRequest", + "ListMetadataSchemasResponse", + "ListMetadataStoresRequest", + "ListMetadataStoresResponse", + "QueryArtifactLineageSubgraphRequest", + "QueryContextLineageSubgraphRequest", + "QueryExecutionInputsAndOutputsRequest", + "UpdateArtifactRequest", + "UpdateContextRequest", + "UpdateExecutionRequest", + "MetadataStore", + "MigratableResource", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceRequest", + "MigrateResourceResponse", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "Model", + "ModelContainerSpec", + "Port", + "PredictSchemata", + "ModelDeploymentMonitoringBigQueryTable", + "ModelDeploymentMonitoringJob", + "ModelDeploymentMonitoringObjectiveConfig", + "ModelDeploymentMonitoringScheduleConfig", + "ModelMonitoringStatsAnomalies", + "ModelDeploymentMonitoringObjectiveType", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelMonitoringAlertConfig", + "ModelMonitoringObjectiveConfig", + "SamplingStrategy", + "ThresholdConfig", + "DeleteModelRequest", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "DeleteOperationMetadata", + "GenericOperationMetadata", + "CancelTrainingPipelineRequest", + "CreateTrainingPipelineRequest", + "DeleteTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "PipelineState", + "ExplainRequest", + "ExplainResponse", + "PredictRequest", + "PredictResponse", + "SpecialistPool", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "DeleteSpecialistPoolRequest", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "Measurement", + "Study", + "StudySpec", + "Trial", + "FilterSplit", + "FractionSplit", + "InputDataConfig", + "PredefinedSplit", + "TimestampSplit", + "TrainingPipeline", + "BoolArray", + "DoubleArray", + "Int64Array", + "StringArray", + "UserActionReference", + "AddTrialMeasurementRequest", + "CheckTrialEarlyStoppingStateMetatdata", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CompleteTrialRequest", + "CreateStudyRequest", + "CreateTrialRequest", + "DeleteStudyRequest", + "DeleteTrialRequest", + "GetStudyRequest", + "GetTrialRequest", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", + "ListStudiesRequest", + "ListStudiesResponse", + "ListTrialsRequest", + "ListTrialsResponse", + "LookupStudyRequest", + "StopTrialRequest", + "SuggestTrialsMetadata", + "SuggestTrialsRequest", + "SuggestTrialsResponse", ) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py index eee6fd93f9..e6881865ca 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_index_ref.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'DeployedIndexRef', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"DeployedIndexRef",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/entity_type.py b/google/cloud/aiplatform_v1beta1/types/entity_type.py index 38448a20c3..eabbe9190a 100644 --- a/google/cloud/aiplatform_v1beta1/types/entity_type.py +++ b/google/cloud/aiplatform_v1beta1/types/entity_type.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'EntityType', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"EntityType",}, ) @@ -83,19 +80,17 @@ class EntityType(proto.Message): description = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=6) etag = proto.Field(proto.STRING, number=7) - monitoring_config = proto.Field(proto.MESSAGE, number=8, + monitoring_config = proto.Field( + proto.MESSAGE, + number=8, message=featurestore_monitoring.FeaturestoreMonitoringConfig, ) diff --git a/google/cloud/aiplatform_v1beta1/types/feature.py b/google/cloud/aiplatform_v1beta1/types/feature.py index f34a825fab..eed5209479 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature.py +++ b/google/cloud/aiplatform_v1beta1/types/feature.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Feature', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Feature",}, ) @@ -89,6 +86,7 @@ class Feature(proto.Message): ``FeatureStatsAnomaly.start_time`` descending. """ + class ValueType(proto.Enum): r"""An enum representing the value type of a feature.""" VALUE_TYPE_UNSPECIFIED = 0 @@ -106,28 +104,24 @@ class ValueType(proto.Enum): description = proto.Field(proto.STRING, number=2) - value_type = proto.Field(proto.ENUM, number=3, - enum=ValueType, - ) + value_type = proto.Field(proto.ENUM, number=3, enum=ValueType,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=6) etag = proto.Field(proto.STRING, number=7) - monitoring_config = proto.Field(proto.MESSAGE, number=9, + monitoring_config = proto.Field( + proto.MESSAGE, + number=9, message=featurestore_monitoring.FeaturestoreMonitoringConfig, ) - monitoring_stats = proto.RepeatedField(proto.MESSAGE, number=10, - message=feature_monitoring_stats.FeatureStatsAnomaly, + monitoring_stats = proto.RepeatedField( + proto.MESSAGE, number=10, message=feature_monitoring_stats.FeatureStatsAnomaly, ) diff --git a/google/cloud/aiplatform_v1beta1/types/feature_selector.py b/google/cloud/aiplatform_v1beta1/types/feature_selector.py index 346029f8f7..cda0ff6713 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature_selector.py +++ b/google/cloud/aiplatform_v1beta1/types/feature_selector.py @@ -19,11 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'IdMatcher', - 'FeatureSelector', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"IdMatcher", "FeatureSelector",}, ) @@ -51,9 +48,7 @@ class FeatureSelector(proto.Message): Required. Matches Features based on ID. """ - id_matcher = proto.Field(proto.MESSAGE, number=1, - message='IdMatcher', - ) + id_matcher = proto.Field(proto.MESSAGE, number=1, message="IdMatcher",) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore.py b/google/cloud/aiplatform_v1beta1/types/featurestore.py index 588d587e2f..378b651b42 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Featurestore', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Featurestore",}, ) @@ -74,6 +71,7 @@ class Featurestore(proto.Message): state (google.cloud.aiplatform_v1beta1.types.Featurestore.State): Output only. State of the featurestore. """ + class State(proto.Enum): r"""Possible states a Featurestore can have.""" STATE_UNSPECIFIED = 0 @@ -111,25 +109,19 @@ class OnlineServingConfig(proto.Message): display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=5) labels = proto.MapField(proto.STRING, proto.STRING, number=6) - online_serving_config = proto.Field(proto.MESSAGE, number=7, - message=OnlineServingConfig, + online_serving_config = proto.Field( + proto.MESSAGE, number=7, message=OnlineServingConfig, ) - state = proto.Field(proto.ENUM, number=8, - enum=State, - ) + state = proto.Field(proto.ENUM, number=8, enum=State,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py b/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py index a13e0778f4..815faaa6fb 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py @@ -22,10 +22,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'FeaturestoreMonitoringConfig', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"FeaturestoreMonitoringConfig",}, ) @@ -37,6 +35,7 @@ class FeaturestoreMonitoringConfig(proto.Message): The config for Snapshot Analysis Based Feature Monitoring. """ + class SnapshotAnalysis(proto.Message): r"""Configuration of the Featurestore's Snapshot Analysis Based Monitoring. This type of analysis generates statistics for each @@ -64,13 +63,11 @@ class SnapshotAnalysis(proto.Message): disabled = proto.Field(proto.BOOL, number=1) - monitoring_interval = proto.Field(proto.MESSAGE, number=2, - message=duration.Duration, + monitoring_interval = proto.Field( + proto.MESSAGE, number=2, message=duration.Duration, ) - snapshot_analysis = proto.Field(proto.MESSAGE, number=1, - message=SnapshotAnalysis, - ) + snapshot_analysis = proto.Field(proto.MESSAGE, number=1, message=SnapshotAnalysis,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py index 2564346039..2ca2fe8dae 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py @@ -18,20 +18,22 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import feature_selector as gca_feature_selector +from google.cloud.aiplatform_v1beta1.types import ( + feature_selector as gca_feature_selector, +) from google.cloud.aiplatform_v1beta1.types import types from google.protobuf import timestamp_pb2 as timestamp # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'ReadFeatureValuesRequest', - 'ReadSetting', - 'ReadFeatureValuesResponse', - 'StreamingReadFeatureValuesRequest', - 'FeatureValue', - 'FeatureValueList', + "ReadFeatureValuesRequest", + "ReadSetting", + "ReadFeatureValuesResponse", + "StreamingReadFeatureValuesRequest", + "FeatureValue", + "FeatureValueList", }, ) @@ -71,16 +73,14 @@ class ReadFeatureValuesRequest(proto.Message): entity_id = proto.Field(proto.STRING, number=2) - feature_selector = proto.Field(proto.MESSAGE, number=3, - message=gca_feature_selector.FeatureSelector, + feature_selector = proto.Field( + proto.MESSAGE, number=3, message=gca_feature_selector.FeatureSelector, ) - setting = proto.Field(proto.MESSAGE, number=5, - message='ReadSetting', - ) + setting = proto.Field(proto.MESSAGE, number=5, message="ReadSetting",) - setting_overrides = proto.MapField(proto.STRING, proto.MESSAGE, number=6, - message='ReadSetting', + setting_overrides = proto.MapField( + proto.STRING, proto.MESSAGE, number=6, message="ReadSetting", ) @@ -102,9 +102,7 @@ class ReadSetting(proto.Message): values_count = proto.Field(proto.INT32, number=2) - read_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + read_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) class ReadFeatureValuesResponse(proto.Message): @@ -121,6 +119,7 @@ class ReadFeatureValuesResponse(proto.Message): entity in the Featurestore if values for only some Features were requested. """ + class FeatureDescriptor(proto.Message): r"""Metadata for requested Features. @@ -149,8 +148,10 @@ class Header(proto.Message): entity_type = proto.Field(proto.STRING, number=1) - feature_descriptors = proto.RepeatedField(proto.MESSAGE, number=2, - message='ReadFeatureValuesResponse.FeatureDescriptor', + feature_descriptors = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ReadFeatureValuesResponse.FeatureDescriptor", ) class EntityView(proto.Message): @@ -167,6 +168,7 @@ class EntityView(proto.Message): header ``ReadFeatureValuesResponse.header``. """ + class Data(proto.Message): r"""Container to hold value(s), successive in time, for one Feature from the request. @@ -182,27 +184,25 @@ class Data(proto.Message): instead of being returned as empty. """ - value = proto.Field(proto.MESSAGE, number=1, oneof='data', - message='FeatureValue', + value = proto.Field( + proto.MESSAGE, number=1, oneof="data", message="FeatureValue", ) - values = proto.Field(proto.MESSAGE, number=2, oneof='data', - message='FeatureValueList', + values = proto.Field( + proto.MESSAGE, number=2, oneof="data", message="FeatureValueList", ) entity_id = proto.Field(proto.STRING, number=1) - data = proto.RepeatedField(proto.MESSAGE, number=2, - message='ReadFeatureValuesResponse.EntityView.Data', + data = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ReadFeatureValuesResponse.EntityView.Data", ) - header = proto.Field(proto.MESSAGE, number=1, - message=Header, - ) + header = proto.Field(proto.MESSAGE, number=1, message=Header,) - entity_view = proto.Field(proto.MESSAGE, number=2, - message=EntityView, - ) + entity_view = proto.Field(proto.MESSAGE, number=2, message=EntityView,) class StreamingReadFeatureValuesRequest(proto.Message): @@ -240,16 +240,14 @@ class StreamingReadFeatureValuesRequest(proto.Message): entity_ids = proto.RepeatedField(proto.STRING, number=2) - feature_selector = proto.Field(proto.MESSAGE, number=3, - message=gca_feature_selector.FeatureSelector, + feature_selector = proto.Field( + proto.MESSAGE, number=3, message=gca_feature_selector.FeatureSelector, ) - setting = proto.Field(proto.MESSAGE, number=5, - message='ReadSetting', - ) + setting = proto.Field(proto.MESSAGE, number=5, message="ReadSetting",) - setting_overrides = proto.MapField(proto.STRING, proto.MESSAGE, number=6, - message='ReadSetting', + setting_overrides = proto.MapField( + proto.STRING, proto.MESSAGE, number=6, message="ReadSetting", ) @@ -279,6 +277,7 @@ class FeatureValue(proto.Message): metadata (google.cloud.aiplatform_v1beta1.types.FeatureValue.Metadata): Output only. Metadata of feature value. """ + class Metadata(proto.Message): r"""Metadata of feature value. @@ -291,39 +290,37 @@ class Metadata(proto.Message): store. """ - generate_time = proto.Field(proto.MESSAGE, number=1, - message=timestamp.Timestamp, + generate_time = proto.Field( + proto.MESSAGE, number=1, message=timestamp.Timestamp, ) - bool_value = proto.Field(proto.BOOL, number=1, oneof='value') + bool_value = proto.Field(proto.BOOL, number=1, oneof="value") - double_value = proto.Field(proto.DOUBLE, number=2, oneof='value') + double_value = proto.Field(proto.DOUBLE, number=2, oneof="value") - int64_value = proto.Field(proto.INT64, number=5, oneof='value') + int64_value = proto.Field(proto.INT64, number=5, oneof="value") - string_value = proto.Field(proto.STRING, number=6, oneof='value') + string_value = proto.Field(proto.STRING, number=6, oneof="value") - bool_array_value = proto.Field(proto.MESSAGE, number=7, oneof='value', - message=types.BoolArray, + bool_array_value = proto.Field( + proto.MESSAGE, number=7, oneof="value", message=types.BoolArray, ) - double_array_value = proto.Field(proto.MESSAGE, number=8, oneof='value', - message=types.DoubleArray, + double_array_value = proto.Field( + proto.MESSAGE, number=8, oneof="value", message=types.DoubleArray, ) - int64_array_value = proto.Field(proto.MESSAGE, number=11, oneof='value', - message=types.Int64Array, + int64_array_value = proto.Field( + proto.MESSAGE, number=11, oneof="value", message=types.Int64Array, ) - string_array_value = proto.Field(proto.MESSAGE, number=12, oneof='value', - message=types.StringArray, + string_array_value = proto.Field( + proto.MESSAGE, number=12, oneof="value", message=types.StringArray, ) - bytes_value = proto.Field(proto.BYTES, number=13, oneof='value') + bytes_value = proto.Field(proto.BYTES, number=13, oneof="value") - metadata = proto.Field(proto.MESSAGE, number=14, - message=Metadata, - ) + metadata = proto.Field(proto.MESSAGE, number=14, message=Metadata,) class FeatureValueList(proto.Message): @@ -335,9 +332,7 @@ class FeatureValueList(proto.Message): be the same data type. """ - values = proto.RepeatedField(proto.MESSAGE, number=1, - message='FeatureValue', - ) + values = proto.RepeatedField(proto.MESSAGE, number=1, message="FeatureValue",) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py index 1844d2ac15..6bf6c284b2 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py @@ -20,7 +20,9 @@ from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type from google.cloud.aiplatform_v1beta1.types import feature as gca_feature -from google.cloud.aiplatform_v1beta1.types import feature_selector as gca_feature_selector +from google.cloud.aiplatform_v1beta1.types import ( + feature_selector as gca_feature_selector, +) from google.cloud.aiplatform_v1beta1.types import featurestore as gca_featurestore from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import operation @@ -29,43 +31,43 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateFeaturestoreRequest', - 'GetFeaturestoreRequest', - 'ListFeaturestoresRequest', - 'ListFeaturestoresResponse', - 'UpdateFeaturestoreRequest', - 'DeleteFeaturestoreRequest', - 'ImportFeatureValuesRequest', - 'ImportFeatureValuesResponse', - 'BatchReadFeatureValuesRequest', - 'DestinationFeatureSetting', - 'FeatureValueDestination', - 'BatchReadFeatureValuesResponse', - 'CreateEntityTypeRequest', - 'GetEntityTypeRequest', - 'ListEntityTypesRequest', - 'ListEntityTypesResponse', - 'UpdateEntityTypeRequest', - 'DeleteEntityTypeRequest', - 'CreateFeatureRequest', - 'BatchCreateFeaturesRequest', - 'BatchCreateFeaturesResponse', - 'GetFeatureRequest', - 'ListFeaturesRequest', - 'ListFeaturesResponse', - 'SearchFeaturesRequest', - 'SearchFeaturesResponse', - 'UpdateFeatureRequest', - 'DeleteFeatureRequest', - 'CreateFeaturestoreOperationMetadata', - 'UpdateFeaturestoreOperationMetadata', - 'ImportFeatureValuesOperationMetadata', - 'BatchReadFeatureValuesOperationMetadata', - 'CreateEntityTypeOperationMetadata', - 'CreateFeatureOperationMetadata', - 'BatchCreateFeaturesOperationMetadata', + "CreateFeaturestoreRequest", + "GetFeaturestoreRequest", + "ListFeaturestoresRequest", + "ListFeaturestoresResponse", + "UpdateFeaturestoreRequest", + "DeleteFeaturestoreRequest", + "ImportFeatureValuesRequest", + "ImportFeatureValuesResponse", + "BatchReadFeatureValuesRequest", + "DestinationFeatureSetting", + "FeatureValueDestination", + "BatchReadFeatureValuesResponse", + "CreateEntityTypeRequest", + "GetEntityTypeRequest", + "ListEntityTypesRequest", + "ListEntityTypesResponse", + "UpdateEntityTypeRequest", + "DeleteEntityTypeRequest", + "CreateFeatureRequest", + "BatchCreateFeaturesRequest", + "BatchCreateFeaturesResponse", + "GetFeatureRequest", + "ListFeaturesRequest", + "ListFeaturesResponse", + "SearchFeaturesRequest", + "SearchFeaturesResponse", + "UpdateFeatureRequest", + "DeleteFeatureRequest", + "CreateFeaturestoreOperationMetadata", + "UpdateFeaturestoreOperationMetadata", + "ImportFeatureValuesOperationMetadata", + "BatchReadFeatureValuesOperationMetadata", + "CreateEntityTypeOperationMetadata", + "CreateFeatureOperationMetadata", + "BatchCreateFeaturesOperationMetadata", }, ) @@ -94,8 +96,8 @@ class CreateFeaturestoreRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - featurestore = proto.Field(proto.MESSAGE, number=2, - message=gca_featurestore.Featurestore, + featurestore = proto.Field( + proto.MESSAGE, number=2, message=gca_featurestore.Featurestore, ) featurestore_id = proto.Field(proto.STRING, number=3) @@ -179,9 +181,7 @@ class ListFeaturestoresRequest(proto.Message): order_by = proto.Field(proto.STRING, number=5) - read_mask = proto.Field(proto.MESSAGE, number=6, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) class ListFeaturestoresResponse(proto.Message): @@ -202,8 +202,8 @@ class ListFeaturestoresResponse(proto.Message): def raw_page(self): return self - featurestores = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_featurestore.Featurestore, + featurestores = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_featurestore.Featurestore, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -236,13 +236,11 @@ class UpdateFeaturestoreRequest(proto.Message): - ``online_serving_config.max_online_serving_size`` """ - featurestore = proto.Field(proto.MESSAGE, number=1, - message=gca_featurestore.Featurestore, + featurestore = proto.Field( + proto.MESSAGE, number=1, message=gca_featurestore.Featurestore, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteFeaturestoreRequest(proto.Message): @@ -310,6 +308,7 @@ class ImportFeatureValuesRequest(proto.Message): value must be greater than 0, and less than or equal to 100. """ + class FeatureSpec(proto.Message): r"""Defines the Feature value(s) to import. @@ -328,21 +327,26 @@ class FeatureSpec(proto.Message): source_field = proto.Field(proto.STRING, number=2) - avro_source = proto.Field(proto.MESSAGE, number=2, oneof='source', - message=io.AvroSource, + avro_source = proto.Field( + proto.MESSAGE, number=2, oneof="source", message=io.AvroSource, ) - bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', - message=io.BigQuerySource, + bigquery_source = proto.Field( + proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, ) - csv_source = proto.Field(proto.MESSAGE, number=4, oneof='source', - message=io.CsvSource, + csv_source = proto.Field( + proto.MESSAGE, number=4, oneof="source", message=io.CsvSource, ) - feature_time_field = proto.Field(proto.STRING, number=6, oneof='feature_time_source') + feature_time_field = proto.Field( + proto.STRING, number=6, oneof="feature_time_source" + ) - feature_time = proto.Field(proto.MESSAGE, number=7, oneof='feature_time_source', + feature_time = proto.Field( + proto.MESSAGE, + number=7, + oneof="feature_time_source", message=timestamp.Timestamp, ) @@ -350,9 +354,7 @@ class FeatureSpec(proto.Message): entity_id_field = proto.Field(proto.STRING, number=5) - feature_specs = proto.RepeatedField(proto.MESSAGE, number=8, - message=FeatureSpec, - ) + feature_specs = proto.RepeatedField(proto.MESSAGE, number=8, message=FeatureSpec,) disable_online_serving = proto.Field(proto.BOOL, number=9) @@ -420,6 +422,7 @@ class BatchReadFeatureValuesRequest(proto.Message): a column specifying entity IDs in tha EntityType in [BatchReadFeatureValuesRequest.request][] . """ + class EntityTypeSpec(proto.Message): r"""Selects Features of an EntityType to read values of and specifies read settings. @@ -439,26 +442,26 @@ class EntityTypeSpec(proto.Message): entity_type_id = proto.Field(proto.STRING, number=1) - feature_selector = proto.Field(proto.MESSAGE, number=2, - message=gca_feature_selector.FeatureSelector, + feature_selector = proto.Field( + proto.MESSAGE, number=2, message=gca_feature_selector.FeatureSelector, ) - settings = proto.RepeatedField(proto.MESSAGE, number=3, - message='DestinationFeatureSetting', + settings = proto.RepeatedField( + proto.MESSAGE, number=3, message="DestinationFeatureSetting", ) - csv_read_instances = proto.Field(proto.MESSAGE, number=3, oneof='read_option', - message=io.CsvSource, + csv_read_instances = proto.Field( + proto.MESSAGE, number=3, oneof="read_option", message=io.CsvSource, ) featurestore = proto.Field(proto.STRING, number=1) - destination = proto.Field(proto.MESSAGE, number=4, - message='FeatureValueDestination', + destination = proto.Field( + proto.MESSAGE, number=4, message="FeatureValueDestination", ) - entity_type_specs = proto.RepeatedField(proto.MESSAGE, number=7, - message=EntityTypeSpec, + entity_type_specs = proto.RepeatedField( + proto.MESSAGE, number=7, message=EntityTypeSpec, ) @@ -509,16 +512,16 @@ class FeatureValueDestination(proto.Message): types are not allowed in CSV format. """ - bigquery_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', - message=io.BigQueryDestination, + bigquery_destination = proto.Field( + proto.MESSAGE, number=1, oneof="destination", message=io.BigQueryDestination, ) - tfrecord_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', - message=io.TFRecordDestination, + tfrecord_destination = proto.Field( + proto.MESSAGE, number=2, oneof="destination", message=io.TFRecordDestination, ) - csv_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', - message=io.CsvDestination, + csv_destination = proto.Field( + proto.MESSAGE, number=3, oneof="destination", message=io.CsvDestination, ) @@ -552,8 +555,8 @@ class CreateEntityTypeRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - entity_type = proto.Field(proto.MESSAGE, number=2, - message=gca_entity_type.EntityType, + entity_type = proto.Field( + proto.MESSAGE, number=2, message=gca_entity_type.EntityType, ) entity_type_id = proto.Field(proto.STRING, number=3) @@ -640,9 +643,7 @@ class ListEntityTypesRequest(proto.Message): order_by = proto.Field(proto.STRING, number=5) - read_mask = proto.Field(proto.MESSAGE, number=6, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) class ListEntityTypesResponse(proto.Message): @@ -663,8 +664,8 @@ class ListEntityTypesResponse(proto.Message): def raw_page(self): return self - entity_types = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_entity_type.EntityType, + entity_types = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_entity_type.EntityType, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -697,13 +698,11 @@ class UpdateEntityTypeRequest(proto.Message): - ``monitoring_config.snapshot_analysis.monitoring_interval`` """ - entity_type = proto.Field(proto.MESSAGE, number=1, - message=gca_entity_type.EntityType, + entity_type = proto.Field( + proto.MESSAGE, number=1, message=gca_entity_type.EntityType, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteEntityTypeRequest(proto.Message): @@ -748,9 +747,7 @@ class CreateFeatureRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - feature = proto.Field(proto.MESSAGE, number=2, - message=gca_feature.Feature, - ) + feature = proto.Field(proto.MESSAGE, number=2, message=gca_feature.Feature,) feature_id = proto.Field(proto.STRING, number=3) @@ -775,8 +772,8 @@ class BatchCreateFeaturesRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - requests = proto.RepeatedField(proto.MESSAGE, number=2, - message='CreateFeatureRequest', + requests = proto.RepeatedField( + proto.MESSAGE, number=2, message="CreateFeatureRequest", ) @@ -789,8 +786,8 @@ class BatchCreateFeaturesResponse(proto.Message): The Features created. """ - features = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_feature.Feature, + features = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_feature.Feature, ) @@ -884,9 +881,7 @@ class ListFeaturesRequest(proto.Message): order_by = proto.Field(proto.STRING, number=5) - read_mask = proto.Field(proto.MESSAGE, number=6, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) latest_stats_count = proto.Field(proto.INT32, number=7) @@ -909,8 +904,8 @@ class ListFeaturesResponse(proto.Message): def raw_page(self): return self - features = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_feature.Feature, + features = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_feature.Feature, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -1046,8 +1041,8 @@ class SearchFeaturesResponse(proto.Message): def raw_page(self): return self - features = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_feature.Feature, + features = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_feature.Feature, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -1079,13 +1074,9 @@ class UpdateFeatureRequest(proto.Message): - ``monitoring_config.snapshot_analysis.monitoring_interval`` """ - feature = proto.Field(proto.MESSAGE, number=1, - message=gca_feature.Feature, - ) + feature = proto.Field(proto.MESSAGE, number=1, message=gca_feature.Feature,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteFeatureRequest(proto.Message): @@ -1109,8 +1100,8 @@ class CreateFeaturestoreOperationMetadata(proto.Message): Operation metadata for Featurestore. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -1122,8 +1113,8 @@ class UpdateFeaturestoreOperationMetadata(proto.Message): Operation metadata for Featurestore. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -1142,8 +1133,8 @@ class ImportFeatureValuesOperationMetadata(proto.Message): imported by the operation. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) imported_entity_count = proto.Field(proto.INT64, number=2) @@ -1160,8 +1151,8 @@ class BatchReadFeatureValuesOperationMetadata(proto.Message): read Features values. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -1173,8 +1164,8 @@ class CreateEntityTypeOperationMetadata(proto.Message): Operation metadata for EntityType. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -1186,8 +1177,8 @@ class CreateFeatureOperationMetadata(proto.Message): Operation metadata for Feature. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -1199,8 +1190,8 @@ class BatchCreateFeaturesOperationMetadata(proto.Message): Operation metadata for Feature. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/index.py b/google/cloud/aiplatform_v1beta1/types/index.py index abf5ebf8ac..54443285d9 100644 --- a/google/cloud/aiplatform_v1beta1/types/index.py +++ b/google/cloud/aiplatform_v1beta1/types/index.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Index', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Index",}, ) @@ -103,25 +100,19 @@ class Index(proto.Message): metadata_schema_uri = proto.Field(proto.STRING, number=4) - metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - deployed_indexes = proto.RepeatedField(proto.MESSAGE, number=7, - message=deployed_index_ref.DeployedIndexRef, + deployed_indexes = proto.RepeatedField( + proto.MESSAGE, number=7, message=deployed_index_ref.DeployedIndexRef, ) etag = proto.Field(proto.STRING, number=8) labels = proto.MapField(proto.STRING, proto.STRING, number=9) - create_time = proto.Field(proto.MESSAGE, number=10, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py index f1bbd1b62a..28ce15cc75 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py @@ -23,12 +23,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'IndexEndpoint', - 'DeployedIndex', - 'DeployedIndexAuthConfig', - 'IndexPrivateEndpoints', + "IndexEndpoint", + "DeployedIndex", + "DeployedIndexAuthConfig", + "IndexPrivateEndpoints", }, ) @@ -97,21 +97,17 @@ class IndexEndpoint(proto.Message): description = proto.Field(proto.STRING, number=3) - deployed_indexes = proto.RepeatedField(proto.MESSAGE, number=4, - message='DeployedIndex', + deployed_indexes = proto.RepeatedField( + proto.MESSAGE, number=4, message="DeployedIndex", ) etag = proto.Field(proto.STRING, number=5) labels = proto.MapField(proto.STRING, proto.STRING, number=6) - create_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) network = proto.Field(proto.STRING, number=9) @@ -194,26 +190,22 @@ class DeployedIndex(proto.Message): display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - private_endpoints = proto.Field(proto.MESSAGE, number=5, - message='IndexPrivateEndpoints', + private_endpoints = proto.Field( + proto.MESSAGE, number=5, message="IndexPrivateEndpoints", ) - index_sync_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + index_sync_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - automatic_resources = proto.Field(proto.MESSAGE, number=7, - message=machine_resources.AutomaticResources, + automatic_resources = proto.Field( + proto.MESSAGE, number=7, message=machine_resources.AutomaticResources, ) enable_access_logging = proto.Field(proto.BOOL, number=8) - deployed_index_auth_config = proto.Field(proto.MESSAGE, number=9, - message='DeployedIndexAuthConfig', + deployed_index_auth_config = proto.Field( + proto.MESSAGE, number=9, message="DeployedIndexAuthConfig", ) @@ -226,6 +218,7 @@ class DeployedIndexAuthConfig(proto.Message): Defines the authentication provider that the DeployedIndex uses. """ + class AuthProvider(proto.Message): r"""Configuration for an authentication provider, including support for `JSON Web Token @@ -241,9 +234,7 @@ class AuthProvider(proto.Message): audiences = proto.RepeatedField(proto.STRING, number=1) - auth_provider = proto.Field(proto.MESSAGE, number=1, - message=AuthProvider, - ) + auth_provider = proto.Field(proto.MESSAGE, number=1, message=AuthProvider,) class IndexPrivateEndpoints(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py index cf5abb0c5a..80b2e04b22 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateIndexEndpointRequest', - 'CreateIndexEndpointOperationMetadata', - 'GetIndexEndpointRequest', - 'ListIndexEndpointsRequest', - 'ListIndexEndpointsResponse', - 'UpdateIndexEndpointRequest', - 'DeleteIndexEndpointRequest', - 'DeployIndexRequest', - 'DeployIndexResponse', - 'DeployIndexOperationMetadata', - 'UndeployIndexRequest', - 'UndeployIndexResponse', - 'UndeployIndexOperationMetadata', + "CreateIndexEndpointRequest", + "CreateIndexEndpointOperationMetadata", + "GetIndexEndpointRequest", + "ListIndexEndpointsRequest", + "ListIndexEndpointsResponse", + "UpdateIndexEndpointRequest", + "DeleteIndexEndpointRequest", + "DeployIndexRequest", + "DeployIndexResponse", + "DeployIndexOperationMetadata", + "UndeployIndexRequest", + "UndeployIndexResponse", + "UndeployIndexOperationMetadata", }, ) @@ -58,8 +58,8 @@ class CreateIndexEndpointRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - index_endpoint = proto.Field(proto.MESSAGE, number=2, - message=gca_index_endpoint.IndexEndpoint, + index_endpoint = proto.Field( + proto.MESSAGE, number=2, message=gca_index_endpoint.IndexEndpoint, ) @@ -72,8 +72,8 @@ class CreateIndexEndpointOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -145,9 +145,7 @@ class ListIndexEndpointsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListIndexEndpointsResponse(proto.Message): @@ -167,8 +165,8 @@ class ListIndexEndpointsResponse(proto.Message): def raw_page(self): return self - index_endpoints = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_index_endpoint.IndexEndpoint, + index_endpoints = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_index_endpoint.IndexEndpoint, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -187,13 +185,11 @@ class UpdateIndexEndpointRequest(proto.Message): `FieldMask `__. """ - index_endpoint = proto.Field(proto.MESSAGE, number=1, - message=gca_index_endpoint.IndexEndpoint, + index_endpoint = proto.Field( + proto.MESSAGE, number=1, message=gca_index_endpoint.IndexEndpoint, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteIndexEndpointRequest(proto.Message): @@ -226,8 +222,8 @@ class DeployIndexRequest(proto.Message): index_endpoint = proto.Field(proto.STRING, number=1) - deployed_index = proto.Field(proto.MESSAGE, number=2, - message=gca_index_endpoint.DeployedIndex, + deployed_index = proto.Field( + proto.MESSAGE, number=2, message=gca_index_endpoint.DeployedIndex, ) @@ -241,8 +237,8 @@ class DeployIndexResponse(proto.Message): the IndexEndpoint. """ - deployed_index = proto.Field(proto.MESSAGE, number=1, - message=gca_index_endpoint.DeployedIndex, + deployed_index = proto.Field( + proto.MESSAGE, number=1, message=gca_index_endpoint.DeployedIndex, ) @@ -255,8 +251,8 @@ class DeployIndexOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -294,8 +290,8 @@ class UndeployIndexOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/index_service.py b/google/cloud/aiplatform_v1beta1/types/index_service.py index 56cb293e93..601f64c6e8 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_service.py +++ b/google/cloud/aiplatform_v1beta1/types/index_service.py @@ -24,17 +24,17 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateIndexRequest', - 'CreateIndexOperationMetadata', - 'GetIndexRequest', - 'ListIndexesRequest', - 'ListIndexesResponse', - 'UpdateIndexRequest', - 'UpdateIndexOperationMetadata', - 'DeleteIndexRequest', - 'NearestNeighborSearchOperationMetadata', + "CreateIndexRequest", + "CreateIndexOperationMetadata", + "GetIndexRequest", + "ListIndexesRequest", + "ListIndexesResponse", + "UpdateIndexRequest", + "UpdateIndexOperationMetadata", + "DeleteIndexRequest", + "NearestNeighborSearchOperationMetadata", }, ) @@ -54,9 +54,7 @@ class CreateIndexRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - index = proto.Field(proto.MESSAGE, number=2, - message=gca_index.Index, - ) + index = proto.Field(proto.MESSAGE, number=2, message=gca_index.Index,) class CreateIndexOperationMetadata(proto.Message): @@ -71,12 +69,12 @@ class CreateIndexOperationMetadata(proto.Message): Index operation. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - nearest_neighbor_search_operation_metadata = proto.Field(proto.MESSAGE, number=2, - message='NearestNeighborSearchOperationMetadata', + nearest_neighbor_search_operation_metadata = proto.Field( + proto.MESSAGE, number=2, message="NearestNeighborSearchOperationMetadata", ) @@ -124,9 +122,7 @@ class ListIndexesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListIndexesResponse(proto.Message): @@ -146,9 +142,7 @@ class ListIndexesResponse(proto.Message): def raw_page(self): return self - indexes = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_index.Index, - ) + indexes = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_index.Index,) next_page_token = proto.Field(proto.STRING, number=2) @@ -166,13 +160,9 @@ class UpdateIndexRequest(proto.Message): `FieldMask `__. """ - index = proto.Field(proto.MESSAGE, number=1, - message=gca_index.Index, - ) + index = proto.Field(proto.MESSAGE, number=1, message=gca_index.Index,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateIndexOperationMetadata(proto.Message): @@ -187,12 +177,12 @@ class UpdateIndexOperationMetadata(proto.Message): Index operation. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - nearest_neighbor_search_operation_metadata = proto.Field(proto.MESSAGE, number=2, - message='NearestNeighborSearchOperationMetadata', + nearest_neighbor_search_operation_metadata = proto.Field( + proto.MESSAGE, number=2, message="NearestNeighborSearchOperationMetadata", ) @@ -223,6 +213,7 @@ class NearestNeighborSearchOperationMetadata(proto.Message): or has unsupported file format, we will not have the stats for those files. """ + class RecordError(proto.Message): r""" @@ -242,6 +233,7 @@ class RecordError(proto.Message): raw_record (str): The original content of this record. """ + class RecordErrorType(proto.Enum): r"""""" ERROR_TYPE_UNSPECIFIED = 0 @@ -253,8 +245,10 @@ class RecordErrorType(proto.Enum): EMBEDDING_SIZE_MISMATCH = 6 NAMESPACE_MISSING = 7 - error_type = proto.Field(proto.ENUM, number=1, - enum='NearestNeighborSearchOperationMetadata.RecordError.RecordErrorType', + error_type = proto.Field( + proto.ENUM, + number=1, + enum="NearestNeighborSearchOperationMetadata.RecordError.RecordErrorType", ) error_message = proto.Field(proto.STRING, number=2) @@ -291,12 +285,14 @@ class ContentValidationStats(proto.Message): invalid_record_count = proto.Field(proto.INT64, number=3) - partial_errors = proto.RepeatedField(proto.MESSAGE, number=4, - message='NearestNeighborSearchOperationMetadata.RecordError', + partial_errors = proto.RepeatedField( + proto.MESSAGE, + number=4, + message="NearestNeighborSearchOperationMetadata.RecordError", ) - content_validation_stats = proto.RepeatedField(proto.MESSAGE, number=1, - message=ContentValidationStats, + content_validation_stats = proto.RepeatedField( + proto.MESSAGE, number=1, message=ContentValidationStats, ) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index 72e3e24e7a..e18a20b132 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -19,17 +19,17 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'AvroSource', - 'CsvSource', - 'GcsSource', - 'GcsDestination', - 'BigQuerySource', - 'BigQueryDestination', - 'CsvDestination', - 'TFRecordDestination', - 'ContainerRegistryDestination', + "AvroSource", + "CsvSource", + "GcsSource", + "GcsDestination", + "BigQuerySource", + "BigQueryDestination", + "CsvDestination", + "TFRecordDestination", + "ContainerRegistryDestination", }, ) @@ -42,9 +42,7 @@ class AvroSource(proto.Message): Required. Google Cloud Storage location. """ - gcs_source = proto.Field(proto.MESSAGE, number=1, - message='GcsSource', - ) + gcs_source = proto.Field(proto.MESSAGE, number=1, message="GcsSource",) class CsvSource(proto.Message): @@ -55,9 +53,7 @@ class CsvSource(proto.Message): Required. Google Cloud Storage location. """ - gcs_source = proto.Field(proto.MESSAGE, number=1, - message='GcsSource', - ) + gcs_source = proto.Field(proto.MESSAGE, number=1, message="GcsSource",) class GcsSource(proto.Message): @@ -133,9 +129,7 @@ class CsvDestination(proto.Message): Required. Google Cloud Storage location. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, - message='GcsDestination', - ) + gcs_destination = proto.Field(proto.MESSAGE, number=1, message="GcsDestination",) class TFRecordDestination(proto.Message): @@ -146,9 +140,7 @@ class TFRecordDestination(proto.Message): Required. Google Cloud Storage location. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, - message='GcsDestination', - ) + gcs_destination = proto.Field(proto.MESSAGE, number=1, message="GcsDestination",) class ContainerRegistryDestination(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_service.py b/google/cloud/aiplatform_v1beta1/types/metadata_service.py index 96ceb992ad..20d13257e7 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_service.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_service.py @@ -29,44 +29,44 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateMetadataStoreRequest', - 'CreateMetadataStoreOperationMetadata', - 'GetMetadataStoreRequest', - 'ListMetadataStoresRequest', - 'ListMetadataStoresResponse', - 'DeleteMetadataStoreRequest', - 'DeleteMetadataStoreOperationMetadata', - 'CreateArtifactRequest', - 'GetArtifactRequest', - 'ListArtifactsRequest', - 'ListArtifactsResponse', - 'UpdateArtifactRequest', - 'CreateContextRequest', - 'GetContextRequest', - 'ListContextsRequest', - 'ListContextsResponse', - 'UpdateContextRequest', - 'DeleteContextRequest', - 'AddContextArtifactsAndExecutionsRequest', - 'AddContextArtifactsAndExecutionsResponse', - 'AddContextChildrenRequest', - 'AddContextChildrenResponse', - 'QueryContextLineageSubgraphRequest', - 'CreateExecutionRequest', - 'GetExecutionRequest', - 'ListExecutionsRequest', - 'ListExecutionsResponse', - 'UpdateExecutionRequest', - 'AddExecutionEventsRequest', - 'AddExecutionEventsResponse', - 'QueryExecutionInputsAndOutputsRequest', - 'CreateMetadataSchemaRequest', - 'GetMetadataSchemaRequest', - 'ListMetadataSchemasRequest', - 'ListMetadataSchemasResponse', - 'QueryArtifactLineageSubgraphRequest', + "CreateMetadataStoreRequest", + "CreateMetadataStoreOperationMetadata", + "GetMetadataStoreRequest", + "ListMetadataStoresRequest", + "ListMetadataStoresResponse", + "DeleteMetadataStoreRequest", + "DeleteMetadataStoreOperationMetadata", + "CreateArtifactRequest", + "GetArtifactRequest", + "ListArtifactsRequest", + "ListArtifactsResponse", + "UpdateArtifactRequest", + "CreateContextRequest", + "GetContextRequest", + "ListContextsRequest", + "ListContextsResponse", + "UpdateContextRequest", + "DeleteContextRequest", + "AddContextArtifactsAndExecutionsRequest", + "AddContextArtifactsAndExecutionsResponse", + "AddContextChildrenRequest", + "AddContextChildrenResponse", + "QueryContextLineageSubgraphRequest", + "CreateExecutionRequest", + "GetExecutionRequest", + "ListExecutionsRequest", + "ListExecutionsResponse", + "UpdateExecutionRequest", + "AddExecutionEventsRequest", + "AddExecutionEventsResponse", + "QueryExecutionInputsAndOutputsRequest", + "CreateMetadataSchemaRequest", + "GetMetadataSchemaRequest", + "ListMetadataSchemasRequest", + "ListMetadataSchemasResponse", + "QueryArtifactLineageSubgraphRequest", }, ) @@ -97,8 +97,8 @@ class CreateMetadataStoreRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - metadata_store = proto.Field(proto.MESSAGE, number=2, - message=gca_metadata_store.MetadataStore, + metadata_store = proto.Field( + proto.MESSAGE, number=2, message=gca_metadata_store.MetadataStore, ) metadata_store_id = proto.Field(proto.STRING, number=3) @@ -114,8 +114,8 @@ class CreateMetadataStoreOperationMetadata(proto.Message): MetadataStore. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -182,8 +182,8 @@ class ListMetadataStoresResponse(proto.Message): def raw_page(self): return self - metadata_stores = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_metadata_store.MetadataStore, + metadata_stores = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_metadata_store.MetadataStore, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -220,8 +220,8 @@ class DeleteMetadataStoreOperationMetadata(proto.Message): MetadataStore. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -250,9 +250,7 @@ class CreateArtifactRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - artifact = proto.Field(proto.MESSAGE, number=2, - message=gca_artifact.Artifact, - ) + artifact = proto.Field(proto.MESSAGE, number=2, message=gca_artifact.Artifact,) artifact_id = proto.Field(proto.STRING, number=3) @@ -325,8 +323,8 @@ class ListArtifactsResponse(proto.Message): def raw_page(self): return self - artifacts = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_artifact.Artifact, + artifacts = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_artifact.Artifact, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -354,13 +352,9 @@ class UpdateArtifactRequest(proto.Message): created. In this situation, ``update_mask`` is ignored. """ - artifact = proto.Field(proto.MESSAGE, number=1, - message=gca_artifact.Artifact, - ) + artifact = proto.Field(proto.MESSAGE, number=1, message=gca_artifact.Artifact,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) allow_missing = proto.Field(proto.BOOL, number=3) @@ -390,9 +384,7 @@ class CreateContextRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - context = proto.Field(proto.MESSAGE, number=2, - message=gca_context.Context, - ) + context = proto.Field(proto.MESSAGE, number=2, message=gca_context.Context,) context_id = proto.Field(proto.STRING, number=3) @@ -465,8 +457,8 @@ class ListContextsResponse(proto.Message): def raw_page(self): return self - contexts = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_context.Context, + contexts = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_context.Context, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -493,13 +485,9 @@ class UpdateContextRequest(proto.Message): created. In this situation, ``update_mask`` is ignored. """ - context = proto.Field(proto.MESSAGE, number=1, - message=gca_context.Context, - ) + context = proto.Field(proto.MESSAGE, number=1, message=gca_context.Context,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) allow_missing = proto.Field(proto.BOOL, number=3) @@ -625,9 +613,7 @@ class CreateExecutionRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - execution = proto.Field(proto.MESSAGE, number=2, - message=gca_execution.Execution, - ) + execution = proto.Field(proto.MESSAGE, number=2, message=gca_execution.Execution,) execution_id = proto.Field(proto.STRING, number=3) @@ -706,8 +692,8 @@ class ListExecutionsResponse(proto.Message): def raw_page(self): return self - executions = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_execution.Execution, + executions = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_execution.Execution, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -735,13 +721,9 @@ class UpdateExecutionRequest(proto.Message): be created. In this situation, ``update_mask`` is ignored. """ - execution = proto.Field(proto.MESSAGE, number=1, - message=gca_execution.Execution, - ) + execution = proto.Field(proto.MESSAGE, number=1, message=gca_execution.Execution,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) allow_missing = proto.Field(proto.BOOL, number=3) @@ -762,9 +744,7 @@ class AddExecutionEventsRequest(proto.Message): execution = proto.Field(proto.STRING, number=1) - events = proto.RepeatedField(proto.MESSAGE, number=2, - message=event.Event, - ) + events = proto.RepeatedField(proto.MESSAGE, number=2, message=event.Event,) class AddExecutionEventsResponse(proto.Message): @@ -815,8 +795,8 @@ class CreateMetadataSchemaRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - metadata_schema = proto.Field(proto.MESSAGE, number=2, - message=gca_metadata_schema.MetadataSchema, + metadata_schema = proto.Field( + proto.MESSAGE, number=2, message=gca_metadata_schema.MetadataSchema, ) metadata_schema_id = proto.Field(proto.STRING, number=3) @@ -891,8 +871,8 @@ class ListMetadataSchemasResponse(proto.Message): def raw_page(self): return self - metadata_schemas = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_metadata_schema.MetadataSchema, + metadata_schemas = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_metadata_schema.MetadataSchema, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/types.py b/google/cloud/aiplatform_v1beta1/types/types.py index c2803a3c3a..127833d18e 100644 --- a/google/cloud/aiplatform_v1beta1/types/types.py +++ b/google/cloud/aiplatform_v1beta1/types/types.py @@ -19,13 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'BoolArray', - 'DoubleArray', - 'Int64Array', - 'StringArray', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"BoolArray", "DoubleArray", "Int64Array", "StringArray",}, ) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 3fe62e7836..76634da6a7 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceAsyncClient +from google.cloud.aiplatform_v1.services.migration_service import ( + MigrationServiceAsyncClient, +) from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceClient from google.cloud.aiplatform_v1.services.migration_service import pagers from google.cloud.aiplatform_v1.services.migration_service import transports @@ -53,7 +55,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -64,36 +70,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -103,7 +126,7 @@ def test_migration_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_client_get_transport_class(): @@ -117,29 +140,44 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) -def test_migration_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -155,7 +193,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -171,7 +209,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -191,13 +229,15 @@ def test_migration_service_client_client_options(client_class, transport_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -210,26 +250,62 @@ def test_migration_service_client_client_options(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -252,10 +328,18 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -276,9 +360,14 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -292,16 +381,23 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -314,16 +410,24 @@ def test_migration_service_client_client_options_scopes(client_class, transport_ client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -338,10 +442,12 @@ def test_migration_service_client_client_options_credentials_file(client_class, def test_migration_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -354,10 +460,12 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -366,12 +474,11 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_migratable_resources(request) @@ -386,7 +493,7 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_migratable_resources_from_dict(): @@ -397,25 +504,27 @@ def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.SearchMigratableResourcesRequest() + @pytest.mark.asyncio -async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): +async def test_search_migratable_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -424,12 +533,14 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.search_migratable_resources(request) @@ -442,7 +553,7 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -451,19 +562,17 @@ async def test_search_migratable_resources_async_from_dict(): def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -475,10 +584,7 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -490,13 +596,15 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) await client.search_migratable_resources(request) @@ -507,49 +615,39 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources( - parent='parent_value', - ) + client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) @@ -561,24 +659,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources( - parent='parent_value', - ) + response = await client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -591,20 +689,17 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -613,17 +708,14 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -636,9 +728,7 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.search_migratable_resources(request={}) @@ -646,18 +736,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in results) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in results + ) + def test_search_migratable_resources_pages(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -666,17 +756,14 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -687,19 +774,20 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -708,17 +796,14 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -729,25 +814,27 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in responses) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in responses + ) + @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -756,17 +843,14 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -779,14 +863,15 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): +def test_batch_migrate_resources( + transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -795,10 +880,10 @@ def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_migrate_resources(request) @@ -820,25 +905,27 @@ def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.BatchMigrateResourcesRequest() + @pytest.mark.asyncio -async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): +async def test_batch_migrate_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.BatchMigrateResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -847,11 +934,11 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_migrate_resources(request) @@ -872,20 +959,18 @@ async def test_batch_migrate_resources_async_from_dict(): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_migrate_resources(request) @@ -896,10 +981,7 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -911,13 +993,15 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_migrate_resources(request) @@ -928,29 +1012,30 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -958,23 +1043,33 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -986,19 +1081,25 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -1006,9 +1107,15 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] @pytest.mark.asyncio @@ -1022,8 +1129,14 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -1034,8 +1147,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1054,8 +1166,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1083,13 +1194,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1097,13 +1211,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MigrationServiceGrpcTransport, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) def test_migration_service_base_transport_error(): @@ -1111,13 +1220,15 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1126,9 +1237,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'search_migratable_resources', - 'batch_migrate_resources', - ) + "search_migratable_resources", + "batch_migrate_resources", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1141,23 +1252,28 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1166,11 +1282,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1178,19 +1294,25 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MigrationServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1199,15 +1321,13 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1222,38 +1342,40 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_migration_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1261,12 +1383,11 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1275,12 +1396,22 @@ def test_migration_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1289,7 +1420,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1305,9 +1436,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1321,17 +1450,23 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1348,9 +1483,7 @@ def test_migration_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1363,16 +1496,12 @@ def test_migration_service_transport_channel_mtls_with_adc( def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1380,16 +1509,12 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1400,17 +1525,20 @@ def test_annotated_dataset_path(): dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) - actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) + actual = MigrationServiceClient.annotated_dataset_path( + project, dataset, annotated_dataset + ) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", - + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1418,20 +1546,22 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", - + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1439,22 +1569,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "scallop" location = "abalone" dataset = "squid" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", - "dataset": "octopus", - + "project": "clam", + "location": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1462,22 +1594,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "oyster" location = "nudibranch" dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", - "dataset": "nautilus", - + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1485,22 +1619,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - + "project": "clam", + "location": "whelk", + "model": "octopus", } path = MigrationServiceClient.model_path(**expected) @@ -1508,22 +1644,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", - + "project": "mussel", + "location": "winkle", + "model": "nautilus", } path = MigrationServiceClient.model_path(**expected) @@ -1531,22 +1669,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + expected = "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", - + "project": "clam", + "model": "whelk", + "version": "octopus", } path = MigrationServiceClient.version_path(**expected) @@ -1554,18 +1694,20 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", - + "billing_account": "nudibranch", } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1573,18 +1715,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", - + "folder": "mussel", } path = MigrationServiceClient.common_folder_path(**expected) @@ -1592,18 +1734,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", - + "organization": "nautilus", } path = MigrationServiceClient.common_organization_path(**expected) @@ -1611,18 +1753,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", - + "project": "abalone", } path = MigrationServiceClient.common_project_path(**expected) @@ -1630,20 +1772,22 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", - + "project": "whelk", + "location": "octopus", } path = MigrationServiceClient.common_location_path(**expected) @@ -1655,17 +1799,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py index 3c99da7fac..e3f19d0271 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -32,9 +32,15 @@ from google.api_core import grpc_helpers_async from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import FeaturestoreOnlineServingServiceClient -from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import transports +from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import ( + FeaturestoreOnlineServingServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import ( + FeaturestoreOnlineServingServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import ( + transports, +) from google.cloud.aiplatform_v1beta1.types import feature_selector from google.cloud.aiplatform_v1beta1.types import featurestore_online_service from google.oauth2 import service_account @@ -49,7 +55,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -59,37 +69,74 @@ def test__get_default_mtls_endpoint(): sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" non_googleapi = "api.example.com" - assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(None) is None - assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(None) is None + ) + assert ( + FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint( + api_mtls_endpoint + ) + == api_mtls_endpoint + ) + assert ( + FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint( + sandbox_endpoint + ) + == sandbox_mtls_endpoint + ) + assert ( + FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint( + sandbox_mtls_endpoint + ) + == sandbox_mtls_endpoint + ) + assert ( + FeaturestoreOnlineServingServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - FeaturestoreOnlineServingServiceClient, - FeaturestoreOnlineServingServiceAsyncClient, -]) -def test_featurestore_online_serving_service_client_from_service_account_info(client_class): +@pytest.mark.parametrize( + "client_class", + [ + FeaturestoreOnlineServingServiceClient, + FeaturestoreOnlineServingServiceAsyncClient, + ], +) +def test_featurestore_online_serving_service_client_from_service_account_info( + client_class, +): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - FeaturestoreOnlineServingServiceClient, - FeaturestoreOnlineServingServiceAsyncClient, -]) -def test_featurestore_online_serving_service_client_from_service_account_file(client_class): +@pytest.mark.parametrize( + "client_class", + [ + FeaturestoreOnlineServingServiceClient, + FeaturestoreOnlineServingServiceAsyncClient, + ], +) +def test_featurestore_online_serving_service_client_from_service_account_file( + client_class, +): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -99,7 +146,7 @@ def test_featurestore_online_serving_service_client_from_service_account_file(cl assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_featurestore_online_serving_service_client_get_transport_class(): @@ -113,29 +160,52 @@ def test_featurestore_online_serving_service_client_get_transport_class(): assert transport == transports.FeaturestoreOnlineServingServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), - (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(FeaturestoreOnlineServingServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceClient)) -@mock.patch.object(FeaturestoreOnlineServingServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient)) -def test_featurestore_online_serving_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + "grpc", + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + FeaturestoreOnlineServingServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceClient), +) +@mock.patch.object( + FeaturestoreOnlineServingServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient), +) +def test_featurestore_online_serving_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(FeaturestoreOnlineServingServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object( + FeaturestoreOnlineServingServiceClient, "get_transport_class" + ) as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(FeaturestoreOnlineServingServiceClient, 'get_transport_class') as gtc: + with mock.patch.object( + FeaturestoreOnlineServingServiceClient, "get_transport_class" + ) as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -151,7 +221,7 @@ def test_featurestore_online_serving_service_client_client_options(client_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -167,7 +237,7 @@ def test_featurestore_online_serving_service_client_client_options(client_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -187,13 +257,15 @@ def test_featurestore_online_serving_service_client_client_options(client_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -206,26 +278,62 @@ def test_featurestore_online_serving_service_client_client_options(client_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc", "true"), - (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc", "false"), - (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(FeaturestoreOnlineServingServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceClient)) -@mock.patch.object(FeaturestoreOnlineServingServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + "grpc", + "true", + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + "grpc", + "false", + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + FeaturestoreOnlineServingServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceClient), +) +@mock.patch.object( + FeaturestoreOnlineServingServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreOnlineServingServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_featurestore_online_serving_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_featurestore_online_serving_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -248,10 +356,18 @@ def test_featurestore_online_serving_service_client_mtls_env_auto(client_class, # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -272,9 +388,14 @@ def test_featurestore_online_serving_service_client_mtls_env_auto(client_class, ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -288,16 +409,27 @@ def test_featurestore_online_serving_service_client_mtls_env_auto(client_class, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), - (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_featurestore_online_serving_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + "grpc", + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_featurestore_online_serving_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -310,16 +442,28 @@ def test_featurestore_online_serving_service_client_client_options_scopes(client client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (FeaturestoreOnlineServingServiceClient, transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), - (FeaturestoreOnlineServingServiceAsyncClient, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_featurestore_online_serving_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + FeaturestoreOnlineServingServiceClient, + transports.FeaturestoreOnlineServingServiceGrpcTransport, + "grpc", + ), + ( + FeaturestoreOnlineServingServiceAsyncClient, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_featurestore_online_serving_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -334,10 +478,12 @@ def test_featurestore_online_serving_service_client_client_options_credentials_f def test_featurestore_online_serving_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = FeaturestoreOnlineServingServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -350,10 +496,12 @@ def test_featurestore_online_serving_service_client_client_options_from_dict(): ) -def test_read_feature_values(transport: str = 'grpc', request_type=featurestore_online_service.ReadFeatureValuesRequest): +def test_read_feature_values( + transport: str = "grpc", + request_type=featurestore_online_service.ReadFeatureValuesRequest, +): client = FeaturestoreOnlineServingServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -362,11 +510,10 @@ def test_read_feature_values(transport: str = 'grpc', request_type=featurestore_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: + type(client.transport.read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = featurestore_online_service.ReadFeatureValuesResponse( - ) + call.return_value = featurestore_online_service.ReadFeatureValuesResponse() response = client.read_feature_values(request) @@ -389,25 +536,27 @@ def test_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreOnlineServingServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: + type(client.transport.read_feature_values), "__call__" + ) as call: client.read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_online_service.ReadFeatureValuesRequest() + @pytest.mark.asyncio -async def test_read_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_online_service.ReadFeatureValuesRequest): +async def test_read_feature_values_async( + transport: str = "grpc_asyncio", + request_type=featurestore_online_service.ReadFeatureValuesRequest, +): client = FeaturestoreOnlineServingServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -416,11 +565,12 @@ async def test_read_feature_values_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: + type(client.transport.read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_online_service.ReadFeatureValuesResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_online_service.ReadFeatureValuesResponse() + ) response = await client.read_feature_values(request) @@ -447,12 +597,12 @@ def test_read_feature_values_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_online_service.ReadFeatureValuesRequest() - request.entity_type = 'entity_type/value' + request.entity_type = "entity_type/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: + type(client.transport.read_feature_values), "__call__" + ) as call: call.return_value = featurestore_online_service.ReadFeatureValuesResponse() client.read_feature_values(request) @@ -464,10 +614,7 @@ def test_read_feature_values_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type=entity_type/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] @pytest.mark.asyncio @@ -479,13 +626,15 @@ async def test_read_feature_values_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_online_service.ReadFeatureValuesRequest() - request.entity_type = 'entity_type/value' + request.entity_type = "entity_type/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_online_service.ReadFeatureValuesResponse()) + type(client.transport.read_feature_values), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_online_service.ReadFeatureValuesResponse() + ) await client.read_feature_values(request) @@ -496,10 +645,7 @@ async def test_read_feature_values_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type=entity_type/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] def test_read_feature_values_flattened(): @@ -509,23 +655,21 @@ def test_read_feature_values_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: + type(client.transport.read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_online_service.ReadFeatureValuesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.read_feature_values( - entity_type='entity_type_value', - ) + client.read_feature_values(entity_type="entity_type_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].entity_type == 'entity_type_value' + assert args[0].entity_type == "entity_type_value" def test_read_feature_values_flattened_error(): @@ -538,7 +682,7 @@ def test_read_feature_values_flattened_error(): with pytest.raises(ValueError): client.read_feature_values( featurestore_online_service.ReadFeatureValuesRequest(), - entity_type='entity_type_value', + entity_type="entity_type_value", ) @@ -550,24 +694,24 @@ async def test_read_feature_values_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.read_feature_values), - '__call__') as call: + type(client.transport.read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_online_service.ReadFeatureValuesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_online_service.ReadFeatureValuesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_online_service.ReadFeatureValuesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.read_feature_values( - entity_type='entity_type_value', - ) + response = await client.read_feature_values(entity_type="entity_type_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].entity_type == 'entity_type_value' + assert args[0].entity_type == "entity_type_value" @pytest.mark.asyncio @@ -581,14 +725,16 @@ async def test_read_feature_values_flattened_error_async(): with pytest.raises(ValueError): await client.read_feature_values( featurestore_online_service.ReadFeatureValuesRequest(), - entity_type='entity_type_value', + entity_type="entity_type_value", ) -def test_streaming_read_feature_values(transport: str = 'grpc', request_type=featurestore_online_service.StreamingReadFeatureValuesRequest): +def test_streaming_read_feature_values( + transport: str = "grpc", + request_type=featurestore_online_service.StreamingReadFeatureValuesRequest, +): client = FeaturestoreOnlineServingServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -597,10 +743,12 @@ def test_streaming_read_feature_values(transport: str = 'grpc', request_type=fea # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + call.return_value = iter( + [featurestore_online_service.ReadFeatureValuesResponse()] + ) response = client.streaming_read_feature_values(request) @@ -608,11 +756,15 @@ def test_streaming_read_feature_values(transport: str = 'grpc', request_type=fea assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + assert ( + args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + ) # Establish that the response is the type that we expect. for message in response: - assert isinstance(message, featurestore_online_service.ReadFeatureValuesResponse) + assert isinstance( + message, featurestore_online_service.ReadFeatureValuesResponse + ) def test_streaming_read_feature_values_from_dict(): @@ -623,25 +775,29 @@ def test_streaming_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreOnlineServingServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: client.streaming_read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + assert ( + args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + ) + @pytest.mark.asyncio -async def test_streaming_read_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_online_service.StreamingReadFeatureValuesRequest): +async def test_streaming_read_feature_values_async( + transport: str = "grpc_asyncio", + request_type=featurestore_online_service.StreamingReadFeatureValuesRequest, +): client = FeaturestoreOnlineServingServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -650,11 +806,13 @@ async def test_streaming_read_feature_values_async(transport: str = 'grpc_asynci # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) - call.return_value.read = mock.AsyncMock(side_effect=[featurestore_online_service.ReadFeatureValuesResponse()]) + call.return_value.read = mock.AsyncMock( + side_effect=[featurestore_online_service.ReadFeatureValuesResponse()] + ) response = await client.streaming_read_feature_values(request) @@ -662,7 +820,9 @@ async def test_streaming_read_feature_values_async(transport: str = 'grpc_asynci assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + assert ( + args[0] == featurestore_online_service.StreamingReadFeatureValuesRequest() + ) # Establish that the response is the type that we expect. message = await response.read() @@ -682,13 +842,15 @@ def test_streaming_read_feature_values_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_online_service.StreamingReadFeatureValuesRequest() - request.entity_type = 'entity_type/value' + request.entity_type = "entity_type/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: - call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: + call.return_value = iter( + [featurestore_online_service.ReadFeatureValuesResponse()] + ) client.streaming_read_feature_values(request) @@ -699,10 +861,7 @@ def test_streaming_read_feature_values_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type=entity_type/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] @pytest.mark.asyncio @@ -714,14 +873,16 @@ async def test_streaming_read_feature_values_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_online_service.StreamingReadFeatureValuesRequest() - request.entity_type = 'entity_type/value' + request.entity_type = "entity_type/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) - call.return_value.read = mock.AsyncMock(side_effect=[featurestore_online_service.ReadFeatureValuesResponse()]) + call.return_value.read = mock.AsyncMock( + side_effect=[featurestore_online_service.ReadFeatureValuesResponse()] + ) await client.streaming_read_feature_values(request) @@ -732,10 +893,7 @@ async def test_streaming_read_feature_values_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type=entity_type/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] def test_streaming_read_feature_values_flattened(): @@ -745,23 +903,23 @@ def test_streaming_read_feature_values_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + call.return_value = iter( + [featurestore_online_service.ReadFeatureValuesResponse()] + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.streaming_read_feature_values( - entity_type='entity_type_value', - ) + client.streaming_read_feature_values(entity_type="entity_type_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].entity_type == 'entity_type_value' + assert args[0].entity_type == "entity_type_value" def test_streaming_read_feature_values_flattened_error(): @@ -774,7 +932,7 @@ def test_streaming_read_feature_values_flattened_error(): with pytest.raises(ValueError): client.streaming_read_feature_values( featurestore_online_service.StreamingReadFeatureValuesRequest(), - entity_type='entity_type_value', + entity_type="entity_type_value", ) @@ -786,16 +944,18 @@ async def test_streaming_read_feature_values_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.streaming_read_feature_values), - '__call__') as call: + type(client.transport.streaming_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = iter([featurestore_online_service.ReadFeatureValuesResponse()]) + call.return_value = iter( + [featurestore_online_service.ReadFeatureValuesResponse()] + ) call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.streaming_read_feature_values( - entity_type='entity_type_value', + entity_type="entity_type_value", ) # Establish that the underlying call was made with the expected @@ -803,7 +963,7 @@ async def test_streaming_read_feature_values_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].entity_type == 'entity_type_value' + assert args[0].entity_type == "entity_type_value" @pytest.mark.asyncio @@ -817,7 +977,7 @@ async def test_streaming_read_feature_values_flattened_error_async(): with pytest.raises(ValueError): await client.streaming_read_feature_values( featurestore_online_service.StreamingReadFeatureValuesRequest(), - entity_type='entity_type_value', + entity_type="entity_type_value", ) @@ -828,8 +988,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = FeaturestoreOnlineServingServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -848,8 +1007,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = FeaturestoreOnlineServingServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -877,13 +1035,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.FeaturestoreOnlineServingServiceGrpcTransport, - transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreOnlineServingServiceGrpcTransport, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -895,8 +1056,7 @@ def test_transport_grpc_default(): credentials=credentials.AnonymousCredentials(), ) assert isinstance( - client.transport, - transports.FeaturestoreOnlineServingServiceGrpcTransport, + client.transport, transports.FeaturestoreOnlineServingServiceGrpcTransport, ) @@ -905,13 +1065,15 @@ def test_featurestore_online_serving_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.FeaturestoreOnlineServingServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_featurestore_online_serving_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.FeaturestoreOnlineServingServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -920,9 +1082,9 @@ def test_featurestore_online_serving_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'read_feature_values', - 'streaming_read_feature_values', - ) + "read_feature_values", + "streaming_read_feature_values", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -930,23 +1092,28 @@ def test_featurestore_online_serving_service_base_transport(): def test_featurestore_online_serving_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.FeaturestoreOnlineServingServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_featurestore_online_serving_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.FeaturestoreOnlineServingServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.FeaturestoreOnlineServingServiceTransport() @@ -955,11 +1122,11 @@ def test_featurestore_online_serving_service_base_transport_with_adc(): def test_featurestore_online_serving_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) FeaturestoreOnlineServingServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -967,18 +1134,26 @@ def test_featurestore_online_serving_service_auth_adc(): def test_featurestore_online_serving_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.FeaturestoreOnlineServingServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.FeaturestoreOnlineServingServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.FeaturestoreOnlineServingServiceGrpcTransport, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreOnlineServingServiceGrpcTransport, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + ], +) def test_featurestore_online_serving_service_grpc_transport_client_cert_source_for_mtls( - transport_class + transport_class, ): cred = credentials.AnonymousCredentials() @@ -988,15 +1163,13 @@ def test_featurestore_online_serving_service_grpc_transport_client_cert_source_f transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1011,38 +1184,40 @@ def test_featurestore_online_serving_service_grpc_transport_client_cert_source_f with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_featurestore_online_serving_service_host_no_port(): client = FeaturestoreOnlineServingServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_featurestore_online_serving_service_host_with_port(): client = FeaturestoreOnlineServingServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_featurestore_online_serving_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.FeaturestoreOnlineServingServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1050,12 +1225,11 @@ def test_featurestore_online_serving_service_grpc_transport_channel(): def test_featurestore_online_serving_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1064,12 +1238,22 @@ def test_featurestore_online_serving_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.FeaturestoreOnlineServingServiceGrpcTransport, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreOnlineServingServiceGrpcTransport, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + ], +) def test_featurestore_online_serving_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1078,7 +1262,7 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_client_ cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1094,9 +1278,7 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_client_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1110,9 +1292,15 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_client_ # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.FeaturestoreOnlineServingServiceGrpcTransport, transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreOnlineServingServiceGrpcTransport, + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + ], +) def test_featurestore_online_serving_service_transport_channel_mtls_with_adc( - transport_class + transport_class, ): mock_ssl_cred = mock.Mock() with mock.patch.multiple( @@ -1120,7 +1308,9 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_adc( __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1137,9 +1327,7 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1156,18 +1344,24 @@ def test_entity_type_path(): featurestore = "whelk" entity_type = "octopus" - expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) - actual = FeaturestoreOnlineServingServiceClient.entity_type_path(project, location, featurestore, entity_type) + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format( + project=project, + location=location, + featurestore=featurestore, + entity_type=entity_type, + ) + actual = FeaturestoreOnlineServingServiceClient.entity_type_path( + project, location, featurestore, entity_type + ) assert expected == actual def test_parse_entity_type_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "featurestore": "cuttlefish", - "entity_type": "mussel", - + "project": "oyster", + "location": "nudibranch", + "featurestore": "cuttlefish", + "entity_type": "mussel", } path = FeaturestoreOnlineServingServiceClient.entity_type_path(**expected) @@ -1175,37 +1369,45 @@ def test_parse_entity_type_path(): actual = FeaturestoreOnlineServingServiceClient.parse_entity_type_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - actual = FeaturestoreOnlineServingServiceClient.common_billing_account_path(billing_account) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = FeaturestoreOnlineServingServiceClient.common_billing_account_path( + billing_account + ) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } - path = FeaturestoreOnlineServingServiceClient.common_billing_account_path(**expected) + path = FeaturestoreOnlineServingServiceClient.common_billing_account_path( + **expected + ) # Check that the path construction is reversible. - actual = FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path(path) + actual = FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path( + path + ) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = FeaturestoreOnlineServingServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = FeaturestoreOnlineServingServiceClient.common_folder_path(**expected) @@ -1213,18 +1415,20 @@ def test_parse_common_folder_path(): actual = FeaturestoreOnlineServingServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) - actual = FeaturestoreOnlineServingServiceClient.common_organization_path(organization) + expected = "organizations/{organization}".format(organization=organization,) + actual = FeaturestoreOnlineServingServiceClient.common_organization_path( + organization + ) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = FeaturestoreOnlineServingServiceClient.common_organization_path(**expected) @@ -1232,18 +1436,18 @@ def test_parse_common_organization_path(): actual = FeaturestoreOnlineServingServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = FeaturestoreOnlineServingServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = FeaturestoreOnlineServingServiceClient.common_project_path(**expected) @@ -1251,20 +1455,24 @@ def test_parse_common_project_path(): actual = FeaturestoreOnlineServingServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) - actual = FeaturestoreOnlineServingServiceClient.common_location_path(project, location) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = FeaturestoreOnlineServingServiceClient.common_location_path( + project, location + ) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = FeaturestoreOnlineServingServiceClient.common_location_path(**expected) @@ -1276,17 +1484,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.FeaturestoreOnlineServingServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.FeaturestoreOnlineServingServiceTransport, "_prep_wrapped_messages" + ) as prep: client = FeaturestoreOnlineServingServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.FeaturestoreOnlineServingServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.FeaturestoreOnlineServingServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = FeaturestoreOnlineServingServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py index 74c27eb5a3..f5e67013c6 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.featurestore_service import FeaturestoreServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.featurestore_service import FeaturestoreServiceClient +from google.cloud.aiplatform_v1beta1.services.featurestore_service import ( + FeaturestoreServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.featurestore_service import ( + FeaturestoreServiceClient, +) from google.cloud.aiplatform_v1beta1.services.featurestore_service import pagers from google.cloud.aiplatform_v1beta1.services.featurestore_service import transports from google.cloud.aiplatform_v1beta1.types import entity_type @@ -66,7 +70,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -77,36 +85,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert FeaturestoreServiceClient._get_default_mtls_endpoint(None) is None - assert FeaturestoreServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert FeaturestoreServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert FeaturestoreServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert FeaturestoreServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert FeaturestoreServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + FeaturestoreServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + FeaturestoreServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + FeaturestoreServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + FeaturestoreServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + FeaturestoreServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - FeaturestoreServiceClient, - FeaturestoreServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [FeaturestoreServiceClient, FeaturestoreServiceAsyncClient,] +) def test_featurestore_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - FeaturestoreServiceClient, - FeaturestoreServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [FeaturestoreServiceClient, FeaturestoreServiceAsyncClient,] +) def test_featurestore_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -116,7 +141,7 @@ def test_featurestore_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_featurestore_service_client_get_transport_class(): @@ -130,29 +155,48 @@ def test_featurestore_service_client_get_transport_class(): assert transport == transports.FeaturestoreServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc"), - (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(FeaturestoreServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceClient)) -@mock.patch.object(FeaturestoreServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceAsyncClient)) -def test_featurestore_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + FeaturestoreServiceClient, + transports.FeaturestoreServiceGrpcTransport, + "grpc", + ), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + FeaturestoreServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceClient), +) +@mock.patch.object( + FeaturestoreServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceAsyncClient), +) +def test_featurestore_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(FeaturestoreServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(FeaturestoreServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(FeaturestoreServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(FeaturestoreServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -168,7 +212,7 @@ def test_featurestore_service_client_client_options(client_class, transport_clas # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -184,7 +228,7 @@ def test_featurestore_service_client_client_options(client_class, transport_clas # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -204,13 +248,15 @@ def test_featurestore_service_client_client_options(client_class, transport_clas client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -223,26 +269,62 @@ def test_featurestore_service_client_client_options(client_class, transport_clas client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc", "true"), - (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc", "false"), - (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(FeaturestoreServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceClient)) -@mock.patch.object(FeaturestoreServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FeaturestoreServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + FeaturestoreServiceClient, + transports.FeaturestoreServiceGrpcTransport, + "grpc", + "true", + ), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + FeaturestoreServiceClient, + transports.FeaturestoreServiceGrpcTransport, + "grpc", + "false", + ), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + FeaturestoreServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceClient), +) +@mock.patch.object( + FeaturestoreServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FeaturestoreServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_featurestore_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_featurestore_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -265,10 +347,18 @@ def test_featurestore_service_client_mtls_env_auto(client_class, transport_class # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -289,9 +379,14 @@ def test_featurestore_service_client_mtls_env_auto(client_class, transport_class ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -305,16 +400,27 @@ def test_featurestore_service_client_mtls_env_auto(client_class, transport_class ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc"), - (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_featurestore_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + FeaturestoreServiceClient, + transports.FeaturestoreServiceGrpcTransport, + "grpc", + ), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_featurestore_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -327,16 +433,28 @@ def test_featurestore_service_client_client_options_scopes(client_class, transpo client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (FeaturestoreServiceClient, transports.FeaturestoreServiceGrpcTransport, "grpc"), - (FeaturestoreServiceAsyncClient, transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_featurestore_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + FeaturestoreServiceClient, + transports.FeaturestoreServiceGrpcTransport, + "grpc", + ), + ( + FeaturestoreServiceAsyncClient, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_featurestore_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -351,10 +469,12 @@ def test_featurestore_service_client_client_options_credentials_file(client_clas def test_featurestore_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = FeaturestoreServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -367,10 +487,11 @@ def test_featurestore_service_client_client_options_from_dict(): ) -def test_create_featurestore(transport: str = 'grpc', request_type=featurestore_service.CreateFeaturestoreRequest): +def test_create_featurestore( + transport: str = "grpc", request_type=featurestore_service.CreateFeaturestoreRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -379,10 +500,10 @@ def test_create_featurestore(transport: str = 'grpc', request_type=featurestore_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: + type(client.transport.create_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_featurestore(request) @@ -404,25 +525,27 @@ def test_create_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: + type(client.transport.create_featurestore), "__call__" + ) as call: client.create_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.CreateFeaturestoreRequest() + @pytest.mark.asyncio -async def test_create_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.CreateFeaturestoreRequest): +async def test_create_featurestore_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.CreateFeaturestoreRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -431,11 +554,11 @@ async def test_create_featurestore_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: + type(client.transport.create_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_featurestore(request) @@ -456,20 +579,18 @@ async def test_create_featurestore_async_from_dict(): def test_create_featurestore_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.CreateFeaturestoreRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_featurestore), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_featurestore(request) @@ -480,10 +601,7 @@ def test_create_featurestore_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -495,13 +613,15 @@ async def test_create_featurestore_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.CreateFeaturestoreRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_featurestore), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_featurestore(request) @@ -512,29 +632,24 @@ async def test_create_featurestore_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_featurestore_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: + type(client.transport.create_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_featurestore( - parent='parent_value', - featurestore=gca_featurestore.Featurestore(name='name_value'), + parent="parent_value", + featurestore=gca_featurestore.Featurestore(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -542,23 +657,21 @@ def test_create_featurestore_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + assert args[0].featurestore == gca_featurestore.Featurestore(name="name_value") def test_create_featurestore_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_featurestore( featurestore_service.CreateFeaturestoreRequest(), - parent='parent_value', - featurestore=gca_featurestore.Featurestore(name='name_value'), + parent="parent_value", + featurestore=gca_featurestore.Featurestore(name="name_value"), ) @@ -570,19 +683,19 @@ async def test_create_featurestore_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_featurestore), - '__call__') as call: + type(client.transport.create_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_featurestore( - parent='parent_value', - featurestore=gca_featurestore.Featurestore(name='name_value'), + parent="parent_value", + featurestore=gca_featurestore.Featurestore(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -590,9 +703,9 @@ async def test_create_featurestore_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + assert args[0].featurestore == gca_featurestore.Featurestore(name="name_value") @pytest.mark.asyncio @@ -606,15 +719,16 @@ async def test_create_featurestore_flattened_error_async(): with pytest.raises(ValueError): await client.create_featurestore( featurestore_service.CreateFeaturestoreRequest(), - parent='parent_value', - featurestore=gca_featurestore.Featurestore(name='name_value'), + parent="parent_value", + featurestore=gca_featurestore.Featurestore(name="name_value"), ) -def test_get_featurestore(transport: str = 'grpc', request_type=featurestore_service.GetFeaturestoreRequest): +def test_get_featurestore( + transport: str = "grpc", request_type=featurestore_service.GetFeaturestoreRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -622,19 +736,13 @@ def test_get_featurestore(transport: str = 'grpc', request_type=featurestore_ser request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore.Featurestore( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", state=featurestore.Featurestore.State.STABLE, - ) response = client.get_featurestore(request) @@ -649,11 +757,11 @@ def test_get_featurestore(transport: str = 'grpc', request_type=featurestore_ser assert isinstance(response, featurestore.Featurestore) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == featurestore.Featurestore.State.STABLE @@ -666,25 +774,25 @@ def test_get_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: client.get_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.GetFeaturestoreRequest() + @pytest.mark.asyncio -async def test_get_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.GetFeaturestoreRequest): +async def test_get_featurestore_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.GetFeaturestoreRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -692,16 +800,16 @@ async def test_get_featurestore_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore.Featurestore( - name='name_value', - display_name='display_name_value', - etag='etag_value', - state=featurestore.Featurestore.State.STABLE, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore.Featurestore( + name="name_value", + display_name="display_name_value", + etag="etag_value", + state=featurestore.Featurestore.State.STABLE, + ) + ) response = await client.get_featurestore(request) @@ -714,11 +822,11 @@ async def test_get_featurestore_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, featurestore.Featurestore) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == featurestore.Featurestore.State.STABLE @@ -729,19 +837,15 @@ async def test_get_featurestore_async_from_dict(): def test_get_featurestore_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.GetFeaturestoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: call.return_value = featurestore.Featurestore() client.get_featurestore(request) @@ -753,10 +857,7 @@ def test_get_featurestore_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -768,13 +869,13 @@ async def test_get_featurestore_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.GetFeaturestoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore.Featurestore()) + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore.Featurestore() + ) await client.get_featurestore(request) @@ -785,49 +886,37 @@ async def test_get_featurestore_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_featurestore_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore.Featurestore() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_featurestore( - name='name_value', - ) + client.get_featurestore(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_featurestore_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_featurestore( - featurestore_service.GetFeaturestoreRequest(), - name='name_value', + featurestore_service.GetFeaturestoreRequest(), name="name_value", ) @@ -838,25 +927,23 @@ async def test_get_featurestore_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_featurestore), - '__call__') as call: + with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore.Featurestore() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore.Featurestore()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore.Featurestore() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_featurestore( - name='name_value', - ) + response = await client.get_featurestore(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -869,15 +956,15 @@ async def test_get_featurestore_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_featurestore( - featurestore_service.GetFeaturestoreRequest(), - name='name_value', + featurestore_service.GetFeaturestoreRequest(), name="name_value", ) -def test_list_featurestores(transport: str = 'grpc', request_type=featurestore_service.ListFeaturestoresRequest): +def test_list_featurestores( + transport: str = "grpc", request_type=featurestore_service.ListFeaturestoresRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -886,12 +973,11 @@ def test_list_featurestores(transport: str = 'grpc', request_type=featurestore_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListFeaturestoresResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_featurestores(request) @@ -906,7 +992,7 @@ def test_list_featurestores(transport: str = 'grpc', request_type=featurestore_s assert isinstance(response, pagers.ListFeaturestoresPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_featurestores_from_dict(): @@ -917,25 +1003,27 @@ def test_list_featurestores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: client.list_featurestores() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.ListFeaturestoresRequest() + @pytest.mark.asyncio -async def test_list_featurestores_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ListFeaturestoresRequest): +async def test_list_featurestores_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.ListFeaturestoresRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -944,12 +1032,14 @@ async def test_list_featurestores_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturestoresResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListFeaturestoresResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_featurestores(request) @@ -962,7 +1052,7 @@ async def test_list_featurestores_async(transport: str = 'grpc_asyncio', request # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListFeaturestoresAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -971,19 +1061,17 @@ async def test_list_featurestores_async_from_dict(): def test_list_featurestores_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ListFeaturestoresRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: call.return_value = featurestore_service.ListFeaturestoresResponse() client.list_featurestores(request) @@ -995,10 +1083,7 @@ def test_list_featurestores_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1010,13 +1095,15 @@ async def test_list_featurestores_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ListFeaturestoresRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturestoresResponse()) + type(client.transport.list_featurestores), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListFeaturestoresResponse() + ) await client.list_featurestores(request) @@ -1027,49 +1114,39 @@ async def test_list_featurestores_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_featurestores_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListFeaturestoresResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_featurestores( - parent='parent_value', - ) + client.list_featurestores(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_featurestores_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_featurestores( - featurestore_service.ListFeaturestoresRequest(), - parent='parent_value', + featurestore_service.ListFeaturestoresRequest(), parent="parent_value", ) @@ -1081,24 +1158,24 @@ async def test_list_featurestores_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListFeaturestoresResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturestoresResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListFeaturestoresResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_featurestores( - parent='parent_value', - ) + response = await client.list_featurestores(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -1111,20 +1188,17 @@ async def test_list_featurestores_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_featurestores( - featurestore_service.ListFeaturestoresRequest(), - parent='parent_value', + featurestore_service.ListFeaturestoresRequest(), parent="parent_value", ) def test_list_featurestores_pager(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturestoresResponse( @@ -1133,17 +1207,13 @@ def test_list_featurestores_pager(): featurestore.Featurestore(), featurestore.Featurestore(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[], - next_page_token='def', + featurestores=[], next_page_token="def", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[ - featurestore.Featurestore(), - ], - next_page_token='ghi', + featurestores=[featurestore.Featurestore(),], next_page_token="ghi", ), featurestore_service.ListFeaturestoresResponse( featurestores=[ @@ -1156,9 +1226,7 @@ def test_list_featurestores_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_featurestores(request={}) @@ -1166,18 +1234,16 @@ def test_list_featurestores_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, featurestore.Featurestore) - for i in results) + assert all(isinstance(i, featurestore.Featurestore) for i in results) + def test_list_featurestores_pages(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__') as call: + type(client.transport.list_featurestores), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturestoresResponse( @@ -1186,17 +1252,13 @@ def test_list_featurestores_pages(): featurestore.Featurestore(), featurestore.Featurestore(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[], - next_page_token='def', + featurestores=[], next_page_token="def", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[ - featurestore.Featurestore(), - ], - next_page_token='ghi', + featurestores=[featurestore.Featurestore(),], next_page_token="ghi", ), featurestore_service.ListFeaturestoresResponse( featurestores=[ @@ -1207,9 +1269,10 @@ def test_list_featurestores_pages(): RuntimeError, ) pages = list(client.list_featurestores(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_featurestores_async_pager(): client = FeaturestoreServiceAsyncClient( @@ -1218,8 +1281,10 @@ async def test_list_featurestores_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_featurestores), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturestoresResponse( @@ -1228,17 +1293,13 @@ async def test_list_featurestores_async_pager(): featurestore.Featurestore(), featurestore.Featurestore(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[], - next_page_token='def', + featurestores=[], next_page_token="def", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[ - featurestore.Featurestore(), - ], - next_page_token='ghi', + featurestores=[featurestore.Featurestore(),], next_page_token="ghi", ), featurestore_service.ListFeaturestoresResponse( featurestores=[ @@ -1249,14 +1310,14 @@ async def test_list_featurestores_async_pager(): RuntimeError, ) async_pager = await client.list_featurestores(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, featurestore.Featurestore) - for i in responses) + assert all(isinstance(i, featurestore.Featurestore) for i in responses) + @pytest.mark.asyncio async def test_list_featurestores_async_pages(): @@ -1266,8 +1327,10 @@ async def test_list_featurestores_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_featurestores), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_featurestores), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturestoresResponse( @@ -1276,17 +1339,13 @@ async def test_list_featurestores_async_pages(): featurestore.Featurestore(), featurestore.Featurestore(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[], - next_page_token='def', + featurestores=[], next_page_token="def", ), featurestore_service.ListFeaturestoresResponse( - featurestores=[ - featurestore.Featurestore(), - ], - next_page_token='ghi', + featurestores=[featurestore.Featurestore(),], next_page_token="ghi", ), featurestore_service.ListFeaturestoresResponse( featurestores=[ @@ -1299,14 +1358,15 @@ async def test_list_featurestores_async_pages(): pages = [] async for page_ in (await client.list_featurestores(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_featurestore(transport: str = 'grpc', request_type=featurestore_service.UpdateFeaturestoreRequest): +def test_update_featurestore( + transport: str = "grpc", request_type=featurestore_service.UpdateFeaturestoreRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1315,10 +1375,10 @@ def test_update_featurestore(transport: str = 'grpc', request_type=featurestore_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: + type(client.transport.update_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_featurestore(request) @@ -1340,25 +1400,27 @@ def test_update_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: + type(client.transport.update_featurestore), "__call__" + ) as call: client.update_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeaturestoreRequest() + @pytest.mark.asyncio -async def test_update_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.UpdateFeaturestoreRequest): +async def test_update_featurestore_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.UpdateFeaturestoreRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1367,11 +1429,11 @@ async def test_update_featurestore_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: + type(client.transport.update_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_featurestore(request) @@ -1392,20 +1454,18 @@ async def test_update_featurestore_async_from_dict(): def test_update_featurestore_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.UpdateFeaturestoreRequest() - request.featurestore.name = 'featurestore.name/value' + request.featurestore.name = "featurestore.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.update_featurestore), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_featurestore(request) @@ -1417,9 +1477,9 @@ def test_update_featurestore_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'featurestore.name=featurestore.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "featurestore.name=featurestore.name/value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1431,13 +1491,15 @@ async def test_update_featurestore_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.UpdateFeaturestoreRequest() - request.featurestore.name = 'featurestore.name/value' + request.featurestore.name = "featurestore.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.update_featurestore), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_featurestore(request) @@ -1449,28 +1511,26 @@ async def test_update_featurestore_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'featurestore.name=featurestore.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "featurestore.name=featurestore.name/value", + ) in kw["metadata"] def test_update_featurestore_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: + type(client.transport.update_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_featurestore( - featurestore=gca_featurestore.Featurestore(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + featurestore=gca_featurestore.Featurestore(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1478,23 +1538,21 @@ def test_update_featurestore_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + assert args[0].featurestore == gca_featurestore.Featurestore(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_featurestore_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_featurestore( featurestore_service.UpdateFeaturestoreRequest(), - featurestore=gca_featurestore.Featurestore(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + featurestore=gca_featurestore.Featurestore(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1506,19 +1564,19 @@ async def test_update_featurestore_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_featurestore), - '__call__') as call: + type(client.transport.update_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_featurestore( - featurestore=gca_featurestore.Featurestore(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + featurestore=gca_featurestore.Featurestore(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1526,9 +1584,9 @@ async def test_update_featurestore_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].featurestore == gca_featurestore.Featurestore(name='name_value') + assert args[0].featurestore == gca_featurestore.Featurestore(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -1542,15 +1600,16 @@ async def test_update_featurestore_flattened_error_async(): with pytest.raises(ValueError): await client.update_featurestore( featurestore_service.UpdateFeaturestoreRequest(), - featurestore=gca_featurestore.Featurestore(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + featurestore=gca_featurestore.Featurestore(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_featurestore(transport: str = 'grpc', request_type=featurestore_service.DeleteFeaturestoreRequest): +def test_delete_featurestore( + transport: str = "grpc", request_type=featurestore_service.DeleteFeaturestoreRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1559,10 +1618,10 @@ def test_delete_featurestore(transport: str = 'grpc', request_type=featurestore_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: + type(client.transport.delete_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_featurestore(request) @@ -1584,25 +1643,27 @@ def test_delete_featurestore_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: + type(client.transport.delete_featurestore), "__call__" + ) as call: client.delete_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.DeleteFeaturestoreRequest() + @pytest.mark.asyncio -async def test_delete_featurestore_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.DeleteFeaturestoreRequest): +async def test_delete_featurestore_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.DeleteFeaturestoreRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1611,11 +1672,11 @@ async def test_delete_featurestore_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: + type(client.transport.delete_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_featurestore(request) @@ -1636,20 +1697,18 @@ async def test_delete_featurestore_async_from_dict(): def test_delete_featurestore_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.DeleteFeaturestoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_featurestore), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_featurestore(request) @@ -1660,10 +1719,7 @@ def test_delete_featurestore_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1675,13 +1731,15 @@ async def test_delete_featurestore_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.DeleteFeaturestoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_featurestore), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_featurestore(request) @@ -1692,49 +1750,39 @@ async def test_delete_featurestore_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_featurestore_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: + type(client.transport.delete_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_featurestore( - name='name_value', - ) + client.delete_featurestore(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_featurestore_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_featurestore( - featurestore_service.DeleteFeaturestoreRequest(), - name='name_value', + featurestore_service.DeleteFeaturestoreRequest(), name="name_value", ) @@ -1746,26 +1794,24 @@ async def test_delete_featurestore_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_featurestore), - '__call__') as call: + type(client.transport.delete_featurestore), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_featurestore( - name='name_value', - ) + response = await client.delete_featurestore(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -1778,15 +1824,15 @@ async def test_delete_featurestore_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_featurestore( - featurestore_service.DeleteFeaturestoreRequest(), - name='name_value', + featurestore_service.DeleteFeaturestoreRequest(), name="name_value", ) -def test_create_entity_type(transport: str = 'grpc', request_type=featurestore_service.CreateEntityTypeRequest): +def test_create_entity_type( + transport: str = "grpc", request_type=featurestore_service.CreateEntityTypeRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1795,10 +1841,10 @@ def test_create_entity_type(transport: str = 'grpc', request_type=featurestore_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: + type(client.transport.create_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_entity_type(request) @@ -1820,25 +1866,27 @@ def test_create_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: + type(client.transport.create_entity_type), "__call__" + ) as call: client.create_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.CreateEntityTypeRequest() + @pytest.mark.asyncio -async def test_create_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.CreateEntityTypeRequest): +async def test_create_entity_type_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.CreateEntityTypeRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1847,11 +1895,11 @@ async def test_create_entity_type_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: + type(client.transport.create_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_entity_type(request) @@ -1872,20 +1920,18 @@ async def test_create_entity_type_async_from_dict(): def test_create_entity_type_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.CreateEntityTypeRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_entity_type), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_entity_type(request) @@ -1896,10 +1942,7 @@ def test_create_entity_type_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1911,13 +1954,15 @@ async def test_create_entity_type_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.CreateEntityTypeRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_entity_type), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_entity_type(request) @@ -1928,29 +1973,24 @@ async def test_create_entity_type_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_entity_type_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: + type(client.transport.create_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_entity_type( - parent='parent_value', - entity_type=gca_entity_type.EntityType(name='name_value'), + parent="parent_value", + entity_type=gca_entity_type.EntityType(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -1958,23 +1998,21 @@ def test_create_entity_type_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + assert args[0].entity_type == gca_entity_type.EntityType(name="name_value") def test_create_entity_type_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_entity_type( featurestore_service.CreateEntityTypeRequest(), - parent='parent_value', - entity_type=gca_entity_type.EntityType(name='name_value'), + parent="parent_value", + entity_type=gca_entity_type.EntityType(name="name_value"), ) @@ -1986,19 +2024,19 @@ async def test_create_entity_type_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_entity_type), - '__call__') as call: + type(client.transport.create_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_entity_type( - parent='parent_value', - entity_type=gca_entity_type.EntityType(name='name_value'), + parent="parent_value", + entity_type=gca_entity_type.EntityType(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -2006,9 +2044,9 @@ async def test_create_entity_type_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + assert args[0].entity_type == gca_entity_type.EntityType(name="name_value") @pytest.mark.asyncio @@ -2022,15 +2060,16 @@ async def test_create_entity_type_flattened_error_async(): with pytest.raises(ValueError): await client.create_entity_type( featurestore_service.CreateEntityTypeRequest(), - parent='parent_value', - entity_type=gca_entity_type.EntityType(name='name_value'), + parent="parent_value", + entity_type=gca_entity_type.EntityType(name="name_value"), ) -def test_get_entity_type(transport: str = 'grpc', request_type=featurestore_service.GetEntityTypeRequest): +def test_get_entity_type( + transport: str = "grpc", request_type=featurestore_service.GetEntityTypeRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2038,17 +2077,10 @@ def test_get_entity_type(transport: str = 'grpc', request_type=featurestore_serv request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = entity_type.EntityType( - name='name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", description="description_value", etag="etag_value", ) response = client.get_entity_type(request) @@ -2063,11 +2095,11 @@ def test_get_entity_type(transport: str = 'grpc', request_type=featurestore_serv assert isinstance(response, entity_type.EntityType) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_entity_type_from_dict(): @@ -2078,25 +2110,25 @@ def test_get_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: client.get_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.GetEntityTypeRequest() + @pytest.mark.asyncio -async def test_get_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.GetEntityTypeRequest): +async def test_get_entity_type_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.GetEntityTypeRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2104,15 +2136,13 @@ async def test_get_entity_type_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(entity_type.EntityType( - name='name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + entity_type.EntityType( + name="name_value", description="description_value", etag="etag_value", + ) + ) response = await client.get_entity_type(request) @@ -2125,11 +2155,11 @@ async def test_get_entity_type_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, entity_type.EntityType) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -2138,19 +2168,15 @@ async def test_get_entity_type_async_from_dict(): def test_get_entity_type_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.GetEntityTypeRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: call.return_value = entity_type.EntityType() client.get_entity_type(request) @@ -2162,10 +2188,7 @@ def test_get_entity_type_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -2177,13 +2200,13 @@ async def test_get_entity_type_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.GetEntityTypeRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(entity_type.EntityType()) + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + entity_type.EntityType() + ) await client.get_entity_type(request) @@ -2194,49 +2217,37 @@ async def test_get_entity_type_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_entity_type_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = entity_type.EntityType() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_entity_type( - name='name_value', - ) + client.get_entity_type(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_entity_type_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_entity_type( - featurestore_service.GetEntityTypeRequest(), - name='name_value', + featurestore_service.GetEntityTypeRequest(), name="name_value", ) @@ -2247,25 +2258,23 @@ async def test_get_entity_type_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_entity_type), - '__call__') as call: + with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = entity_type.EntityType() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(entity_type.EntityType()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + entity_type.EntityType() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_entity_type( - name='name_value', - ) + response = await client.get_entity_type(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -2278,15 +2287,15 @@ async def test_get_entity_type_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_entity_type( - featurestore_service.GetEntityTypeRequest(), - name='name_value', + featurestore_service.GetEntityTypeRequest(), name="name_value", ) -def test_list_entity_types(transport: str = 'grpc', request_type=featurestore_service.ListEntityTypesRequest): +def test_list_entity_types( + transport: str = "grpc", request_type=featurestore_service.ListEntityTypesRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2295,12 +2304,11 @@ def test_list_entity_types(transport: str = 'grpc', request_type=featurestore_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListEntityTypesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_entity_types(request) @@ -2315,7 +2323,7 @@ def test_list_entity_types(transport: str = 'grpc', request_type=featurestore_se assert isinstance(response, pagers.ListEntityTypesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_entity_types_from_dict(): @@ -2326,25 +2334,27 @@ def test_list_entity_types_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: client.list_entity_types() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.ListEntityTypesRequest() + @pytest.mark.asyncio -async def test_list_entity_types_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ListEntityTypesRequest): +async def test_list_entity_types_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.ListEntityTypesRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2353,12 +2363,14 @@ async def test_list_entity_types_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListEntityTypesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListEntityTypesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_entity_types(request) @@ -2371,7 +2383,7 @@ async def test_list_entity_types_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEntityTypesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2380,19 +2392,17 @@ async def test_list_entity_types_async_from_dict(): def test_list_entity_types_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ListEntityTypesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: call.return_value = featurestore_service.ListEntityTypesResponse() client.list_entity_types(request) @@ -2404,10 +2414,7 @@ def test_list_entity_types_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -2419,13 +2426,15 @@ async def test_list_entity_types_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ListEntityTypesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListEntityTypesResponse()) + type(client.transport.list_entity_types), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListEntityTypesResponse() + ) await client.list_entity_types(request) @@ -2436,49 +2445,39 @@ async def test_list_entity_types_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_entity_types_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListEntityTypesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_entity_types( - parent='parent_value', - ) + client.list_entity_types(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_entity_types_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_entity_types( - featurestore_service.ListEntityTypesRequest(), - parent='parent_value', + featurestore_service.ListEntityTypesRequest(), parent="parent_value", ) @@ -2490,24 +2489,24 @@ async def test_list_entity_types_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListEntityTypesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListEntityTypesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListEntityTypesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_entity_types( - parent='parent_value', - ) + response = await client.list_entity_types(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -2520,20 +2519,17 @@ async def test_list_entity_types_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_entity_types( - featurestore_service.ListEntityTypesRequest(), - parent='parent_value', + featurestore_service.ListEntityTypesRequest(), parent="parent_value", ) def test_list_entity_types_pager(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListEntityTypesResponse( @@ -2542,32 +2538,23 @@ def test_list_entity_types_pager(): entity_type.EntityType(), entity_type.EntityType(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListEntityTypesResponse( - entity_types=[], - next_page_token='def', + entity_types=[], next_page_token="def", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - ], - next_page_token='ghi', + entity_types=[entity_type.EntityType(),], next_page_token="ghi", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - entity_type.EntityType(), - ], + entity_types=[entity_type.EntityType(), entity_type.EntityType(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_entity_types(request={}) @@ -2575,18 +2562,16 @@ def test_list_entity_types_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, entity_type.EntityType) - for i in results) + assert all(isinstance(i, entity_type.EntityType) for i in results) + def test_list_entity_types_pages(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__') as call: + type(client.transport.list_entity_types), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListEntityTypesResponse( @@ -2595,30 +2580,24 @@ def test_list_entity_types_pages(): entity_type.EntityType(), entity_type.EntityType(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListEntityTypesResponse( - entity_types=[], - next_page_token='def', + entity_types=[], next_page_token="def", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - ], - next_page_token='ghi', + entity_types=[entity_type.EntityType(),], next_page_token="ghi", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - entity_type.EntityType(), - ], + entity_types=[entity_type.EntityType(), entity_type.EntityType(),], ), RuntimeError, ) pages = list(client.list_entity_types(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_entity_types_async_pager(): client = FeaturestoreServiceAsyncClient( @@ -2627,8 +2606,10 @@ async def test_list_entity_types_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_entity_types), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListEntityTypesResponse( @@ -2637,35 +2618,28 @@ async def test_list_entity_types_async_pager(): entity_type.EntityType(), entity_type.EntityType(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListEntityTypesResponse( - entity_types=[], - next_page_token='def', + entity_types=[], next_page_token="def", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - ], - next_page_token='ghi', + entity_types=[entity_type.EntityType(),], next_page_token="ghi", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - entity_type.EntityType(), - ], + entity_types=[entity_type.EntityType(), entity_type.EntityType(),], ), RuntimeError, ) async_pager = await client.list_entity_types(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, entity_type.EntityType) - for i in responses) + assert all(isinstance(i, entity_type.EntityType) for i in responses) + @pytest.mark.asyncio async def test_list_entity_types_async_pages(): @@ -2675,8 +2649,10 @@ async def test_list_entity_types_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_entity_types), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_entity_types), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListEntityTypesResponse( @@ -2685,37 +2661,31 @@ async def test_list_entity_types_async_pages(): entity_type.EntityType(), entity_type.EntityType(), ], - next_page_token='abc', + next_page_token="abc", ), featurestore_service.ListEntityTypesResponse( - entity_types=[], - next_page_token='def', + entity_types=[], next_page_token="def", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - ], - next_page_token='ghi', + entity_types=[entity_type.EntityType(),], next_page_token="ghi", ), featurestore_service.ListEntityTypesResponse( - entity_types=[ - entity_type.EntityType(), - entity_type.EntityType(), - ], + entity_types=[entity_type.EntityType(), entity_type.EntityType(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_entity_types(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_entity_type(transport: str = 'grpc', request_type=featurestore_service.UpdateEntityTypeRequest): +def test_update_entity_type( + transport: str = "grpc", request_type=featurestore_service.UpdateEntityTypeRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2724,16 +2694,11 @@ def test_update_entity_type(transport: str = 'grpc', request_type=featurestore_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: + type(client.transport.update_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_entity_type.EntityType( - name='name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", description="description_value", etag="etag_value", ) response = client.update_entity_type(request) @@ -2748,11 +2713,11 @@ def test_update_entity_type(transport: str = 'grpc', request_type=featurestore_s assert isinstance(response, gca_entity_type.EntityType) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_entity_type_from_dict(): @@ -2763,25 +2728,27 @@ def test_update_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: + type(client.transport.update_entity_type), "__call__" + ) as call: client.update_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateEntityTypeRequest() + @pytest.mark.asyncio -async def test_update_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.UpdateEntityTypeRequest): +async def test_update_entity_type_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.UpdateEntityTypeRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2790,14 +2757,14 @@ async def test_update_entity_type_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: + type(client.transport.update_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_entity_type.EntityType( - name='name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_entity_type.EntityType( + name="name_value", description="description_value", etag="etag_value", + ) + ) response = await client.update_entity_type(request) @@ -2810,11 +2777,11 @@ async def test_update_entity_type_async(transport: str = 'grpc_asyncio', request # Establish that the response is the type that we expect. assert isinstance(response, gca_entity_type.EntityType) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -2823,19 +2790,17 @@ async def test_update_entity_type_async_from_dict(): def test_update_entity_type_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.UpdateEntityTypeRequest() - request.entity_type.name = 'entity_type.name/value' + request.entity_type.name = "entity_type.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: + type(client.transport.update_entity_type), "__call__" + ) as call: call.return_value = gca_entity_type.EntityType() client.update_entity_type(request) @@ -2847,10 +2812,9 @@ def test_update_entity_type_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type.name=entity_type.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type.name=entity_type.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio @@ -2862,13 +2826,15 @@ async def test_update_entity_type_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.UpdateEntityTypeRequest() - request.entity_type.name = 'entity_type.name/value' + request.entity_type.name = "entity_type.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_entity_type.EntityType()) + type(client.transport.update_entity_type), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_entity_type.EntityType() + ) await client.update_entity_type(request) @@ -2879,29 +2845,26 @@ async def test_update_entity_type_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type.name=entity_type.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type.name=entity_type.name/value",) in kw[ + "metadata" + ] def test_update_entity_type_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: + type(client.transport.update_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_entity_type.EntityType() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_entity_type( - entity_type=gca_entity_type.EntityType(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + entity_type=gca_entity_type.EntityType(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2909,23 +2872,21 @@ def test_update_entity_type_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + assert args[0].entity_type == gca_entity_type.EntityType(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_entity_type_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_entity_type( featurestore_service.UpdateEntityTypeRequest(), - entity_type=gca_entity_type.EntityType(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + entity_type=gca_entity_type.EntityType(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -2937,17 +2898,19 @@ async def test_update_entity_type_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_entity_type), - '__call__') as call: + type(client.transport.update_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_entity_type.EntityType() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_entity_type.EntityType()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_entity_type.EntityType() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_entity_type( - entity_type=gca_entity_type.EntityType(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + entity_type=gca_entity_type.EntityType(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2955,9 +2918,9 @@ async def test_update_entity_type_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].entity_type == gca_entity_type.EntityType(name='name_value') + assert args[0].entity_type == gca_entity_type.EntityType(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -2971,15 +2934,16 @@ async def test_update_entity_type_flattened_error_async(): with pytest.raises(ValueError): await client.update_entity_type( featurestore_service.UpdateEntityTypeRequest(), - entity_type=gca_entity_type.EntityType(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + entity_type=gca_entity_type.EntityType(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_entity_type(transport: str = 'grpc', request_type=featurestore_service.DeleteEntityTypeRequest): +def test_delete_entity_type( + transport: str = "grpc", request_type=featurestore_service.DeleteEntityTypeRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2988,10 +2952,10 @@ def test_delete_entity_type(transport: str = 'grpc', request_type=featurestore_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: + type(client.transport.delete_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_entity_type(request) @@ -3013,25 +2977,27 @@ def test_delete_entity_type_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: + type(client.transport.delete_entity_type), "__call__" + ) as call: client.delete_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.DeleteEntityTypeRequest() + @pytest.mark.asyncio -async def test_delete_entity_type_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.DeleteEntityTypeRequest): +async def test_delete_entity_type_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.DeleteEntityTypeRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3040,11 +3006,11 @@ async def test_delete_entity_type_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: + type(client.transport.delete_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_entity_type(request) @@ -3065,20 +3031,18 @@ async def test_delete_entity_type_async_from_dict(): def test_delete_entity_type_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.DeleteEntityTypeRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_entity_type), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_entity_type(request) @@ -3089,10 +3053,7 @@ def test_delete_entity_type_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -3104,13 +3065,15 @@ async def test_delete_entity_type_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.DeleteEntityTypeRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_entity_type), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_entity_type(request) @@ -3121,49 +3084,39 @@ async def test_delete_entity_type_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_entity_type_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: + type(client.transport.delete_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_entity_type( - name='name_value', - ) + client.delete_entity_type(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_entity_type_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_entity_type( - featurestore_service.DeleteEntityTypeRequest(), - name='name_value', + featurestore_service.DeleteEntityTypeRequest(), name="name_value", ) @@ -3175,26 +3128,24 @@ async def test_delete_entity_type_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_entity_type), - '__call__') as call: + type(client.transport.delete_entity_type), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_entity_type( - name='name_value', - ) + response = await client.delete_entity_type(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -3207,15 +3158,15 @@ async def test_delete_entity_type_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_entity_type( - featurestore_service.DeleteEntityTypeRequest(), - name='name_value', + featurestore_service.DeleteEntityTypeRequest(), name="name_value", ) -def test_create_feature(transport: str = 'grpc', request_type=featurestore_service.CreateFeatureRequest): +def test_create_feature( + transport: str = "grpc", request_type=featurestore_service.CreateFeatureRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3223,11 +3174,9 @@ def test_create_feature(transport: str = 'grpc', request_type=featurestore_servi request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_feature(request) @@ -3249,25 +3198,25 @@ def test_create_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: client.create_feature() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.CreateFeatureRequest() + @pytest.mark.asyncio -async def test_create_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.CreateFeatureRequest): +async def test_create_feature_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.CreateFeatureRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3275,12 +3224,10 @@ async def test_create_feature_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_feature(request) @@ -3301,20 +3248,16 @@ async def test_create_feature_async_from_dict(): def test_create_feature_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.CreateFeatureRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_feature(request) @@ -3325,10 +3268,7 @@ def test_create_feature_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -3340,13 +3280,13 @@ async def test_create_feature_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.CreateFeatureRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_feature(request) @@ -3357,29 +3297,21 @@ async def test_create_feature_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_feature_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_feature( - parent='parent_value', - feature=gca_feature.Feature(name='name_value'), + parent="parent_value", feature=gca_feature.Feature(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -3387,23 +3319,21 @@ def test_create_feature_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].feature == gca_feature.Feature(name='name_value') + assert args[0].feature == gca_feature.Feature(name="name_value") def test_create_feature_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_feature( featurestore_service.CreateFeatureRequest(), - parent='parent_value', - feature=gca_feature.Feature(name='name_value'), + parent="parent_value", + feature=gca_feature.Feature(name="name_value"), ) @@ -3414,20 +3344,17 @@ async def test_create_feature_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.create_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_feature( - parent='parent_value', - feature=gca_feature.Feature(name='name_value'), + parent="parent_value", feature=gca_feature.Feature(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -3435,9 +3362,9 @@ async def test_create_feature_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].feature == gca_feature.Feature(name='name_value') + assert args[0].feature == gca_feature.Feature(name="name_value") @pytest.mark.asyncio @@ -3451,15 +3378,17 @@ async def test_create_feature_flattened_error_async(): with pytest.raises(ValueError): await client.create_feature( featurestore_service.CreateFeatureRequest(), - parent='parent_value', - feature=gca_feature.Feature(name='name_value'), + parent="parent_value", + feature=gca_feature.Feature(name="name_value"), ) -def test_batch_create_features(transport: str = 'grpc', request_type=featurestore_service.BatchCreateFeaturesRequest): +def test_batch_create_features( + transport: str = "grpc", + request_type=featurestore_service.BatchCreateFeaturesRequest, +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3468,10 +3397,10 @@ def test_batch_create_features(transport: str = 'grpc', request_type=featurestor # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: + type(client.transport.batch_create_features), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_create_features(request) @@ -3493,25 +3422,27 @@ def test_batch_create_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: + type(client.transport.batch_create_features), "__call__" + ) as call: client.batch_create_features() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.BatchCreateFeaturesRequest() + @pytest.mark.asyncio -async def test_batch_create_features_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.BatchCreateFeaturesRequest): +async def test_batch_create_features_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.BatchCreateFeaturesRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3520,11 +3451,11 @@ async def test_batch_create_features_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: + type(client.transport.batch_create_features), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_create_features(request) @@ -3545,20 +3476,18 @@ async def test_batch_create_features_async_from_dict(): def test_batch_create_features_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.BatchCreateFeaturesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_create_features), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_create_features(request) @@ -3569,10 +3498,7 @@ def test_batch_create_features_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -3584,13 +3510,15 @@ async def test_batch_create_features_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.BatchCreateFeaturesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_create_features), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_create_features(request) @@ -3601,29 +3529,24 @@ async def test_batch_create_features_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_batch_create_features_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: + type(client.transport.batch_create_features), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_create_features( - parent='parent_value', - requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + parent="parent_value", + requests=[featurestore_service.CreateFeatureRequest(parent="parent_value")], ) # Establish that the underlying call was made with the expected @@ -3631,23 +3554,23 @@ def test_batch_create_features_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].requests == [featurestore_service.CreateFeatureRequest(parent='parent_value')] + assert args[0].requests == [ + featurestore_service.CreateFeatureRequest(parent="parent_value") + ] def test_batch_create_features_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_create_features( featurestore_service.BatchCreateFeaturesRequest(), - parent='parent_value', - requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + parent="parent_value", + requests=[featurestore_service.CreateFeatureRequest(parent="parent_value")], ) @@ -3659,19 +3582,19 @@ async def test_batch_create_features_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_create_features), - '__call__') as call: + type(client.transport.batch_create_features), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_create_features( - parent='parent_value', - requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + parent="parent_value", + requests=[featurestore_service.CreateFeatureRequest(parent="parent_value")], ) # Establish that the underlying call was made with the expected @@ -3679,9 +3602,11 @@ async def test_batch_create_features_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].requests == [featurestore_service.CreateFeatureRequest(parent='parent_value')] + assert args[0].requests == [ + featurestore_service.CreateFeatureRequest(parent="parent_value") + ] @pytest.mark.asyncio @@ -3695,15 +3620,16 @@ async def test_batch_create_features_flattened_error_async(): with pytest.raises(ValueError): await client.batch_create_features( featurestore_service.BatchCreateFeaturesRequest(), - parent='parent_value', - requests=[featurestore_service.CreateFeatureRequest(parent='parent_value')], + parent="parent_value", + requests=[featurestore_service.CreateFeatureRequest(parent="parent_value")], ) -def test_get_feature(transport: str = 'grpc', request_type=featurestore_service.GetFeatureRequest): +def test_get_feature( + transport: str = "grpc", request_type=featurestore_service.GetFeatureRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3711,19 +3637,13 @@ def test_get_feature(transport: str = 'grpc', request_type=featurestore_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = feature.Feature( - name='name_value', - - description='description_value', - + name="name_value", + description="description_value", value_type=feature.Feature.ValueType.BOOL, - - etag='etag_value', - + etag="etag_value", ) response = client.get_feature(request) @@ -3738,13 +3658,13 @@ def test_get_feature(transport: str = 'grpc', request_type=featurestore_service. assert isinstance(response, feature.Feature) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" assert response.value_type == feature.Feature.ValueType.BOOL - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_feature_from_dict(): @@ -3755,25 +3675,24 @@ def test_get_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: client.get_feature() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.GetFeatureRequest() + @pytest.mark.asyncio -async def test_get_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.GetFeatureRequest): +async def test_get_feature_async( + transport: str = "grpc_asyncio", request_type=featurestore_service.GetFeatureRequest +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3781,16 +3700,16 @@ async def test_get_feature_async(transport: str = 'grpc_asyncio', request_type=f request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(feature.Feature( - name='name_value', - description='description_value', - value_type=feature.Feature.ValueType.BOOL, - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + feature.Feature( + name="name_value", + description="description_value", + value_type=feature.Feature.ValueType.BOOL, + etag="etag_value", + ) + ) response = await client.get_feature(request) @@ -3803,13 +3722,13 @@ async def test_get_feature_async(transport: str = 'grpc_asyncio', request_type=f # Establish that the response is the type that we expect. assert isinstance(response, feature.Feature) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" assert response.value_type == feature.Feature.ValueType.BOOL - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -3818,19 +3737,15 @@ async def test_get_feature_async_from_dict(): def test_get_feature_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.GetFeatureRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: call.return_value = feature.Feature() client.get_feature(request) @@ -3842,10 +3757,7 @@ def test_get_feature_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -3857,12 +3769,10 @@ async def test_get_feature_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.GetFeatureRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(feature.Feature()) await client.get_feature(request) @@ -3874,49 +3784,37 @@ async def test_get_feature_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_feature_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = feature.Feature() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_feature( - name='name_value', - ) + client.get_feature(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_feature_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_feature( - featurestore_service.GetFeatureRequest(), - name='name_value', + featurestore_service.GetFeatureRequest(), name="name_value", ) @@ -3927,25 +3825,21 @@ async def test_get_feature_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.get_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = feature.Feature() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(feature.Feature()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_feature( - name='name_value', - ) + response = await client.get_feature(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -3958,15 +3852,15 @@ async def test_get_feature_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_feature( - featurestore_service.GetFeatureRequest(), - name='name_value', + featurestore_service.GetFeatureRequest(), name="name_value", ) -def test_list_features(transport: str = 'grpc', request_type=featurestore_service.ListFeaturesRequest): +def test_list_features( + transport: str = "grpc", request_type=featurestore_service.ListFeaturesRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3974,13 +3868,10 @@ def test_list_features(transport: str = 'grpc', request_type=featurestore_servic request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListFeaturesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_features(request) @@ -3995,7 +3886,7 @@ def test_list_features(transport: str = 'grpc', request_type=featurestore_servic assert isinstance(response, pagers.ListFeaturesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_features_from_dict(): @@ -4006,25 +3897,25 @@ def test_list_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: client.list_features() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.ListFeaturesRequest() + @pytest.mark.asyncio -async def test_list_features_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ListFeaturesRequest): +async def test_list_features_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.ListFeaturesRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4032,13 +3923,13 @@ async def test_list_features_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListFeaturesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_features(request) @@ -4051,7 +3942,7 @@ async def test_list_features_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListFeaturesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -4060,19 +3951,15 @@ async def test_list_features_async_from_dict(): def test_list_features_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ListFeaturesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: call.return_value = featurestore_service.ListFeaturesResponse() client.list_features(request) @@ -4084,10 +3971,7 @@ def test_list_features_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -4099,13 +3983,13 @@ async def test_list_features_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ListFeaturesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturesResponse()) + with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListFeaturesResponse() + ) await client.list_features(request) @@ -4116,49 +4000,37 @@ async def test_list_features_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_features_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListFeaturesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_features( - parent='parent_value', - ) + client.list_features(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_features_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_features( - featurestore_service.ListFeaturesRequest(), - parent='parent_value', + featurestore_service.ListFeaturesRequest(), parent="parent_value", ) @@ -4169,25 +4041,23 @@ async def test_list_features_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.ListFeaturesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.ListFeaturesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.ListFeaturesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_features( - parent='parent_value', - ) + response = await client.list_features(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -4200,54 +4070,36 @@ async def test_list_features_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_features( - featurestore_service.ListFeaturesRequest(), - parent='parent_value', + featurestore_service.ListFeaturesRequest(), parent="parent_value", ) def test_list_features_pager(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.ListFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_features(request={}) @@ -4255,50 +4107,36 @@ def test_list_features_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, feature.Feature) - for i in results) + assert all(isinstance(i, feature.Feature) for i in results) + def test_list_features_pages(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_features), - '__call__') as call: + with mock.patch.object(type(client.transport.list_features), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.ListFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) pages = list(client.list_features(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_features_async_pager(): client = FeaturestoreServiceAsyncClient( @@ -4307,45 +4145,34 @@ async def test_list_features_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_features), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_features), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.ListFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) async_pager = await client.list_features(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, feature.Feature) - for i in responses) + assert all(isinstance(i, feature.Feature) for i in responses) + @pytest.mark.asyncio async def test_list_features_async_pages(): @@ -4355,47 +4182,37 @@ async def test_list_features_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_features), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_features), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.ListFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.ListFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_features(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_feature(transport: str = 'grpc', request_type=featurestore_service.UpdateFeatureRequest): +def test_update_feature( + transport: str = "grpc", request_type=featurestore_service.UpdateFeatureRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4403,19 +4220,13 @@ def test_update_feature(transport: str = 'grpc', request_type=featurestore_servi request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_feature.Feature( - name='name_value', - - description='description_value', - + name="name_value", + description="description_value", value_type=gca_feature.Feature.ValueType.BOOL, - - etag='etag_value', - + etag="etag_value", ) response = client.update_feature(request) @@ -4430,13 +4241,13 @@ def test_update_feature(transport: str = 'grpc', request_type=featurestore_servi assert isinstance(response, gca_feature.Feature) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" assert response.value_type == gca_feature.Feature.ValueType.BOOL - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_feature_from_dict(): @@ -4447,25 +4258,25 @@ def test_update_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: client.update_feature() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeatureRequest() + @pytest.mark.asyncio -async def test_update_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.UpdateFeatureRequest): +async def test_update_feature_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.UpdateFeatureRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4473,16 +4284,16 @@ async def test_update_feature_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_feature.Feature( - name='name_value', - description='description_value', - value_type=gca_feature.Feature.ValueType.BOOL, - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_feature.Feature( + name="name_value", + description="description_value", + value_type=gca_feature.Feature.ValueType.BOOL, + etag="etag_value", + ) + ) response = await client.update_feature(request) @@ -4495,13 +4306,13 @@ async def test_update_feature_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_feature.Feature) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.description == 'description_value' + assert response.description == "description_value" assert response.value_type == gca_feature.Feature.ValueType.BOOL - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -4510,19 +4321,15 @@ async def test_update_feature_async_from_dict(): def test_update_feature_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.UpdateFeatureRequest() - request.feature.name = 'feature.name/value' + request.feature.name = "feature.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: call.return_value = gca_feature.Feature() client.update_feature(request) @@ -4534,10 +4341,9 @@ def test_update_feature_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'feature.name=feature.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "feature.name=feature.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio @@ -4549,12 +4355,10 @@ async def test_update_feature_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.UpdateFeatureRequest() - request.feature.name = 'feature.name/value' + request.feature.name = "feature.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_feature.Feature()) await client.update_feature(request) @@ -4566,29 +4370,24 @@ async def test_update_feature_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'feature.name=feature.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "feature.name=feature.name/value",) in kw[ + "metadata" + ] def test_update_feature_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_feature.Feature() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_feature( - feature=gca_feature.Feature(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + feature=gca_feature.Feature(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -4596,23 +4395,21 @@ def test_update_feature_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].feature == gca_feature.Feature(name='name_value') + assert args[0].feature == gca_feature.Feature(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_feature_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_feature( featurestore_service.UpdateFeatureRequest(), - feature=gca_feature.Feature(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + feature=gca_feature.Feature(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -4623,9 +4420,7 @@ async def test_update_feature_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.update_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_feature.Feature() @@ -4633,8 +4428,8 @@ async def test_update_feature_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_feature( - feature=gca_feature.Feature(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + feature=gca_feature.Feature(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -4642,9 +4437,9 @@ async def test_update_feature_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].feature == gca_feature.Feature(name='name_value') + assert args[0].feature == gca_feature.Feature(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -4658,15 +4453,16 @@ async def test_update_feature_flattened_error_async(): with pytest.raises(ValueError): await client.update_feature( featurestore_service.UpdateFeatureRequest(), - feature=gca_feature.Feature(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + feature=gca_feature.Feature(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_feature(transport: str = 'grpc', request_type=featurestore_service.DeleteFeatureRequest): +def test_delete_feature( + transport: str = "grpc", request_type=featurestore_service.DeleteFeatureRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4674,11 +4470,9 @@ def test_delete_feature(transport: str = 'grpc', request_type=featurestore_servi request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_feature(request) @@ -4700,25 +4494,25 @@ def test_delete_feature_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: client.delete_feature() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.DeleteFeatureRequest() + @pytest.mark.asyncio -async def test_delete_feature_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.DeleteFeatureRequest): +async def test_delete_feature_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.DeleteFeatureRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4726,12 +4520,10 @@ async def test_delete_feature_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_feature(request) @@ -4752,20 +4544,16 @@ async def test_delete_feature_async_from_dict(): def test_delete_feature_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.DeleteFeatureRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_feature(request) @@ -4776,10 +4564,7 @@ def test_delete_feature_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -4791,13 +4576,13 @@ async def test_delete_feature_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.DeleteFeatureRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_feature(request) @@ -4808,49 +4593,37 @@ async def test_delete_feature_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_feature_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_feature( - name='name_value', - ) + client.delete_feature(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_feature_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_feature( - featurestore_service.DeleteFeatureRequest(), - name='name_value', + featurestore_service.DeleteFeatureRequest(), name="name_value", ) @@ -4861,27 +4634,23 @@ async def test_delete_feature_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_feature), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_feature( - name='name_value', - ) + response = await client.delete_feature(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -4894,15 +4663,16 @@ async def test_delete_feature_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_feature( - featurestore_service.DeleteFeatureRequest(), - name='name_value', + featurestore_service.DeleteFeatureRequest(), name="name_value", ) -def test_import_feature_values(transport: str = 'grpc', request_type=featurestore_service.ImportFeatureValuesRequest): +def test_import_feature_values( + transport: str = "grpc", + request_type=featurestore_service.ImportFeatureValuesRequest, +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4911,10 +4681,10 @@ def test_import_feature_values(transport: str = 'grpc', request_type=featurestor # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: + type(client.transport.import_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.import_feature_values(request) @@ -4936,25 +4706,27 @@ def test_import_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: + type(client.transport.import_feature_values), "__call__" + ) as call: client.import_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.ImportFeatureValuesRequest() + @pytest.mark.asyncio -async def test_import_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.ImportFeatureValuesRequest): +async def test_import_feature_values_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.ImportFeatureValuesRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4963,11 +4735,11 @@ async def test_import_feature_values_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: + type(client.transport.import_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.import_feature_values(request) @@ -4988,20 +4760,18 @@ async def test_import_feature_values_async_from_dict(): def test_import_feature_values_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ImportFeatureValuesRequest() - request.entity_type = 'entity_type/value' + request.entity_type = "entity_type/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.import_feature_values), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.import_feature_values(request) @@ -5012,10 +4782,7 @@ def test_import_feature_values_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type=entity_type/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] @pytest.mark.asyncio @@ -5027,13 +4794,15 @@ async def test_import_feature_values_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.ImportFeatureValuesRequest() - request.entity_type = 'entity_type/value' + request.entity_type = "entity_type/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.import_feature_values), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.import_feature_values(request) @@ -5044,49 +4813,40 @@ async def test_import_feature_values_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'entity_type=entity_type/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] def test_import_feature_values_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: + type(client.transport.import_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.import_feature_values( - entity_type='entity_type_value', - ) + client.import_feature_values(entity_type="entity_type_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].entity_type == 'entity_type_value' + assert args[0].entity_type == "entity_type_value" def test_import_feature_values_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_feature_values( featurestore_service.ImportFeatureValuesRequest(), - entity_type='entity_type_value', + entity_type="entity_type_value", ) @@ -5098,26 +4858,24 @@ async def test_import_feature_values_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.import_feature_values), - '__call__') as call: + type(client.transport.import_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.import_feature_values( - entity_type='entity_type_value', - ) + response = await client.import_feature_values(entity_type="entity_type_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].entity_type == 'entity_type_value' + assert args[0].entity_type == "entity_type_value" @pytest.mark.asyncio @@ -5131,14 +4889,16 @@ async def test_import_feature_values_flattened_error_async(): with pytest.raises(ValueError): await client.import_feature_values( featurestore_service.ImportFeatureValuesRequest(), - entity_type='entity_type_value', + entity_type="entity_type_value", ) -def test_batch_read_feature_values(transport: str = 'grpc', request_type=featurestore_service.BatchReadFeatureValuesRequest): +def test_batch_read_feature_values( + transport: str = "grpc", + request_type=featurestore_service.BatchReadFeatureValuesRequest, +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5147,10 +4907,10 @@ def test_batch_read_feature_values(transport: str = 'grpc', request_type=feature # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: + type(client.transport.batch_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_read_feature_values(request) @@ -5172,25 +4932,27 @@ def test_batch_read_feature_values_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: + type(client.transport.batch_read_feature_values), "__call__" + ) as call: client.batch_read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.BatchReadFeatureValuesRequest() + @pytest.mark.asyncio -async def test_batch_read_feature_values_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.BatchReadFeatureValuesRequest): +async def test_batch_read_feature_values_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.BatchReadFeatureValuesRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5199,11 +4961,11 @@ async def test_batch_read_feature_values_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: + type(client.transport.batch_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_read_feature_values(request) @@ -5224,20 +4986,18 @@ async def test_batch_read_feature_values_async_from_dict(): def test_batch_read_feature_values_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.BatchReadFeatureValuesRequest() - request.featurestore = 'featurestore/value' + request.featurestore = "featurestore/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_read_feature_values), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_read_feature_values(request) @@ -5248,10 +5008,9 @@ def test_batch_read_feature_values_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'featurestore=featurestore/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "featurestore=featurestore/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio @@ -5263,13 +5022,15 @@ async def test_batch_read_feature_values_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.BatchReadFeatureValuesRequest() - request.featurestore = 'featurestore/value' + request.featurestore = "featurestore/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_read_feature_values), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_read_feature_values(request) @@ -5280,49 +5041,42 @@ async def test_batch_read_feature_values_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'featurestore=featurestore/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "featurestore=featurestore/value",) in kw[ + "metadata" + ] def test_batch_read_feature_values_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: + type(client.transport.batch_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.batch_read_feature_values( - featurestore='featurestore_value', - ) + client.batch_read_feature_values(featurestore="featurestore_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].featurestore == 'featurestore_value' + assert args[0].featurestore == "featurestore_value" def test_batch_read_feature_values_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_read_feature_values( featurestore_service.BatchReadFeatureValuesRequest(), - featurestore='featurestore_value', + featurestore="featurestore_value", ) @@ -5334,18 +5088,18 @@ async def test_batch_read_feature_values_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_read_feature_values), - '__call__') as call: + type(client.transport.batch_read_feature_values), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_read_feature_values( - featurestore='featurestore_value', + featurestore="featurestore_value", ) # Establish that the underlying call was made with the expected @@ -5353,7 +5107,7 @@ async def test_batch_read_feature_values_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].featurestore == 'featurestore_value' + assert args[0].featurestore == "featurestore_value" @pytest.mark.asyncio @@ -5367,14 +5121,15 @@ async def test_batch_read_feature_values_flattened_error_async(): with pytest.raises(ValueError): await client.batch_read_feature_values( featurestore_service.BatchReadFeatureValuesRequest(), - featurestore='featurestore_value', + featurestore="featurestore_value", ) -def test_search_features(transport: str = 'grpc', request_type=featurestore_service.SearchFeaturesRequest): +def test_search_features( + transport: str = "grpc", request_type=featurestore_service.SearchFeaturesRequest +): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5382,13 +5137,10 @@ def test_search_features(transport: str = 'grpc', request_type=featurestore_serv request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.SearchFeaturesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_features(request) @@ -5403,7 +5155,7 @@ def test_search_features(transport: str = 'grpc', request_type=featurestore_serv assert isinstance(response, pagers.SearchFeaturesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_features_from_dict(): @@ -5414,25 +5166,25 @@ def test_search_features_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: client.search_features() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.SearchFeaturesRequest() + @pytest.mark.asyncio -async def test_search_features_async(transport: str = 'grpc_asyncio', request_type=featurestore_service.SearchFeaturesRequest): +async def test_search_features_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.SearchFeaturesRequest, +): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5440,13 +5192,13 @@ async def test_search_features_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.SearchFeaturesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.SearchFeaturesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.search_features(request) @@ -5459,7 +5211,7 @@ async def test_search_features_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchFeaturesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -5468,19 +5220,15 @@ async def test_search_features_async_from_dict(): def test_search_features_field_headers(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.SearchFeaturesRequest() - request.location = 'location/value' + request.location = "location/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: call.return_value = featurestore_service.SearchFeaturesResponse() client.search_features(request) @@ -5492,10 +5240,7 @@ def test_search_features_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'location=location/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "location=location/value",) in kw["metadata"] @pytest.mark.asyncio @@ -5507,13 +5252,13 @@ async def test_search_features_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = featurestore_service.SearchFeaturesRequest() - request.location = 'location/value' + request.location = "location/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.SearchFeaturesResponse()) + with mock.patch.object(type(client.transport.search_features), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.SearchFeaturesResponse() + ) await client.search_features(request) @@ -5524,49 +5269,37 @@ async def test_search_features_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'location=location/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "location=location/value",) in kw["metadata"] def test_search_features_flattened(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.SearchFeaturesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_features( - location='location_value', - ) + client.search_features(location="location_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].location == 'location_value' + assert args[0].location == "location_value" def test_search_features_flattened_error(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_features( - featurestore_service.SearchFeaturesRequest(), - location='location_value', + featurestore_service.SearchFeaturesRequest(), location="location_value", ) @@ -5577,25 +5310,23 @@ async def test_search_features_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = featurestore_service.SearchFeaturesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(featurestore_service.SearchFeaturesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + featurestore_service.SearchFeaturesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_features( - location='location_value', - ) + response = await client.search_features(location="location_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].location == 'location_value' + assert args[0].location == "location_value" @pytest.mark.asyncio @@ -5608,54 +5339,36 @@ async def test_search_features_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_features( - featurestore_service.SearchFeaturesRequest(), - location='location_value', + featurestore_service.SearchFeaturesRequest(), location="location_value", ) def test_search_features_pager(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.SearchFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('location', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("location", ""),)), ) pager = client.search_features(request={}) @@ -5663,50 +5376,36 @@ def test_search_features_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, feature.Feature) - for i in results) + assert all(isinstance(i, feature.Feature) for i in results) + def test_search_features_pages(): - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_features), - '__call__') as call: + with mock.patch.object(type(client.transport.search_features), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.SearchFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) pages = list(client.search_features(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_features_async_pager(): client = FeaturestoreServiceAsyncClient( @@ -5715,45 +5414,34 @@ async def test_search_features_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_features), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_features), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.SearchFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) async_pager = await client.search_features(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, feature.Feature) - for i in responses) + assert all(isinstance(i, feature.Feature) for i in responses) + @pytest.mark.asyncio async def test_search_features_async_pages(): @@ -5763,40 +5451,29 @@ async def test_search_features_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_features), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_features), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - feature.Feature(), - ], - next_page_token='abc', + features=[feature.Feature(), feature.Feature(), feature.Feature(),], + next_page_token="abc", ), featurestore_service.SearchFeaturesResponse( - features=[], - next_page_token='def', + features=[], next_page_token="def", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - ], - next_page_token='ghi', + features=[feature.Feature(),], next_page_token="ghi", ), featurestore_service.SearchFeaturesResponse( - features=[ - feature.Feature(), - feature.Feature(), - ], + features=[feature.Feature(), feature.Feature(),], ), RuntimeError, ) pages = [] async for page_ in (await client.search_features(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -5807,8 +5484,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -5827,8 +5503,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = FeaturestoreServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -5856,13 +5531,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.FeaturestoreServiceGrpcTransport, - transports.FeaturestoreServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreServiceGrpcTransport, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -5870,13 +5548,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.FeaturestoreServiceGrpcTransport, - ) + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.FeaturestoreServiceGrpcTransport,) def test_featurestore_service_base_transport_error(): @@ -5884,13 +5557,15 @@ def test_featurestore_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.FeaturestoreServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_featurestore_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.FeaturestoreServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -5899,26 +5574,26 @@ def test_featurestore_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_featurestore', - 'get_featurestore', - 'list_featurestores', - 'update_featurestore', - 'delete_featurestore', - 'create_entity_type', - 'get_entity_type', - 'list_entity_types', - 'update_entity_type', - 'delete_entity_type', - 'create_feature', - 'batch_create_features', - 'get_feature', - 'list_features', - 'update_feature', - 'delete_feature', - 'import_feature_values', - 'batch_read_feature_values', - 'search_features', - ) + "create_featurestore", + "get_featurestore", + "list_featurestores", + "update_featurestore", + "delete_featurestore", + "create_entity_type", + "get_entity_type", + "list_entity_types", + "update_entity_type", + "delete_entity_type", + "create_feature", + "batch_create_features", + "get_feature", + "list_features", + "update_feature", + "delete_feature", + "import_feature_values", + "batch_read_feature_values", + "search_features", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -5931,23 +5606,28 @@ def test_featurestore_service_base_transport(): def test_featurestore_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.FeaturestoreServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_featurestore_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.FeaturestoreServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.FeaturestoreServiceTransport() @@ -5956,11 +5636,11 @@ def test_featurestore_service_base_transport_with_adc(): def test_featurestore_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) FeaturestoreServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -5968,18 +5648,26 @@ def test_featurestore_service_auth_adc(): def test_featurestore_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.FeaturestoreServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.FeaturestoreServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.FeaturestoreServiceGrpcTransport, transports.FeaturestoreServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreServiceGrpcTransport, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + ], +) def test_featurestore_service_grpc_transport_client_cert_source_for_mtls( - transport_class + transport_class, ): cred = credentials.AnonymousCredentials() @@ -5989,15 +5677,13 @@ def test_featurestore_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -6012,38 +5698,40 @@ def test_featurestore_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_featurestore_service_host_no_port(): client = FeaturestoreServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_featurestore_service_host_with_port(): client = FeaturestoreServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_featurestore_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.FeaturestoreServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6051,12 +5739,11 @@ def test_featurestore_service_grpc_transport_channel(): def test_featurestore_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.FeaturestoreServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -6065,12 +5752,22 @@ def test_featurestore_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.FeaturestoreServiceGrpcTransport, transports.FeaturestoreServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreServiceGrpcTransport, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + ], +) def test_featurestore_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -6079,7 +5776,7 @@ def test_featurestore_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -6095,9 +5792,7 @@ def test_featurestore_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6111,17 +5806,23 @@ def test_featurestore_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.FeaturestoreServiceGrpcTransport, transports.FeaturestoreServiceGrpcAsyncIOTransport]) -def test_featurestore_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.FeaturestoreServiceGrpcTransport, + transports.FeaturestoreServiceGrpcAsyncIOTransport, + ], +) +def test_featurestore_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -6138,9 +5839,7 @@ def test_featurestore_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6153,16 +5852,12 @@ def test_featurestore_service_transport_channel_mtls_with_adc( def test_featurestore_service_grpc_lro_client(): client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6170,16 +5865,12 @@ def test_featurestore_service_grpc_lro_client(): def test_featurestore_service_grpc_lro_async_client(): client = FeaturestoreServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6191,18 +5882,24 @@ def test_entity_type_path(): featurestore = "whelk" entity_type = "octopus" - expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, ) - actual = FeaturestoreServiceClient.entity_type_path(project, location, featurestore, entity_type) + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format( + project=project, + location=location, + featurestore=featurestore, + entity_type=entity_type, + ) + actual = FeaturestoreServiceClient.entity_type_path( + project, location, featurestore, entity_type + ) assert expected == actual def test_parse_entity_type_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "featurestore": "cuttlefish", - "entity_type": "mussel", - + "project": "oyster", + "location": "nudibranch", + "featurestore": "cuttlefish", + "entity_type": "mussel", } path = FeaturestoreServiceClient.entity_type_path(**expected) @@ -6210,6 +5907,7 @@ def test_parse_entity_type_path(): actual = FeaturestoreServiceClient.parse_entity_type_path(path) assert expected == actual + def test_feature_path(): project = "winkle" location = "nautilus" @@ -6217,19 +5915,26 @@ def test_feature_path(): entity_type = "abalone" feature = "squid" - expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format(project=project, location=location, featurestore=featurestore, entity_type=entity_type, feature=feature, ) - actual = FeaturestoreServiceClient.feature_path(project, location, featurestore, entity_type, feature) + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format( + project=project, + location=location, + featurestore=featurestore, + entity_type=entity_type, + feature=feature, + ) + actual = FeaturestoreServiceClient.feature_path( + project, location, featurestore, entity_type, feature + ) assert expected == actual def test_parse_feature_path(): expected = { - "project": "clam", - "location": "whelk", - "featurestore": "octopus", - "entity_type": "oyster", - "feature": "nudibranch", - + "project": "clam", + "location": "whelk", + "featurestore": "octopus", + "entity_type": "oyster", + "feature": "nudibranch", } path = FeaturestoreServiceClient.feature_path(**expected) @@ -6237,22 +5942,26 @@ def test_parse_feature_path(): actual = FeaturestoreServiceClient.parse_feature_path(path) assert expected == actual + def test_featurestore_path(): project = "cuttlefish" location = "mussel" featurestore = "winkle" - expected = "projects/{project}/locations/{location}/featurestores/{featurestore}".format(project=project, location=location, featurestore=featurestore, ) - actual = FeaturestoreServiceClient.featurestore_path(project, location, featurestore) + expected = "projects/{project}/locations/{location}/featurestores/{featurestore}".format( + project=project, location=location, featurestore=featurestore, + ) + actual = FeaturestoreServiceClient.featurestore_path( + project, location, featurestore + ) assert expected == actual def test_parse_featurestore_path(): expected = { - "project": "nautilus", - "location": "scallop", - "featurestore": "abalone", - + "project": "nautilus", + "location": "scallop", + "featurestore": "abalone", } path = FeaturestoreServiceClient.featurestore_path(**expected) @@ -6260,18 +5969,20 @@ def test_parse_featurestore_path(): actual = FeaturestoreServiceClient.parse_featurestore_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = FeaturestoreServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = FeaturestoreServiceClient.common_billing_account_path(**expected) @@ -6279,18 +5990,18 @@ def test_parse_common_billing_account_path(): actual = FeaturestoreServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = FeaturestoreServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = FeaturestoreServiceClient.common_folder_path(**expected) @@ -6298,18 +6009,18 @@ def test_parse_common_folder_path(): actual = FeaturestoreServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = FeaturestoreServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = FeaturestoreServiceClient.common_organization_path(**expected) @@ -6317,18 +6028,18 @@ def test_parse_common_organization_path(): actual = FeaturestoreServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = FeaturestoreServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = FeaturestoreServiceClient.common_project_path(**expected) @@ -6336,20 +6047,22 @@ def test_parse_common_project_path(): actual = FeaturestoreServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = FeaturestoreServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = FeaturestoreServiceClient.common_location_path(**expected) @@ -6361,17 +6074,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.FeaturestoreServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.FeaturestoreServiceTransport, "_prep_wrapped_messages" + ) as prep: client = FeaturestoreServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.FeaturestoreServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.FeaturestoreServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = FeaturestoreServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index c8209a3cae..5e8e860b32 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import IndexEndpointServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import IndexEndpointServiceClient +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import ( + IndexEndpointServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import ( + IndexEndpointServiceClient, +) from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import index_endpoint @@ -58,7 +62,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -69,36 +77,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert IndexEndpointServiceClient._get_default_mtls_endpoint(None) is None - assert IndexEndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert IndexEndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert IndexEndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert IndexEndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert IndexEndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + IndexEndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + IndexEndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + IndexEndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + IndexEndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + IndexEndpointServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - IndexEndpointServiceClient, - IndexEndpointServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [IndexEndpointServiceClient, IndexEndpointServiceAsyncClient,] +) def test_index_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - IndexEndpointServiceClient, - IndexEndpointServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [IndexEndpointServiceClient, IndexEndpointServiceAsyncClient,] +) def test_index_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -108,7 +133,7 @@ def test_index_endpoint_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_index_endpoint_service_client_get_transport_class(): @@ -122,29 +147,48 @@ def test_index_endpoint_service_client_get_transport_class(): assert transport == transports.IndexEndpointServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc"), - (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(IndexEndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceClient)) -@mock.patch.object(IndexEndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceAsyncClient)) -def test_index_endpoint_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + IndexEndpointServiceClient, + transports.IndexEndpointServiceGrpcTransport, + "grpc", + ), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + IndexEndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceClient), +) +@mock.patch.object( + IndexEndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceAsyncClient), +) +def test_index_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(IndexEndpointServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(IndexEndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(IndexEndpointServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(IndexEndpointServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -160,7 +204,7 @@ def test_index_endpoint_service_client_client_options(client_class, transport_cl # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -176,7 +220,7 @@ def test_index_endpoint_service_client_client_options(client_class, transport_cl # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -196,13 +240,15 @@ def test_index_endpoint_service_client_client_options(client_class, transport_cl client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -215,26 +261,62 @@ def test_index_endpoint_service_client_client_options(client_class, transport_cl client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc", "true"), - (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc", "false"), - (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(IndexEndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceClient)) -@mock.patch.object(IndexEndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexEndpointServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + IndexEndpointServiceClient, + transports.IndexEndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + IndexEndpointServiceClient, + transports.IndexEndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + IndexEndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceClient), +) +@mock.patch.object( + IndexEndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexEndpointServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_index_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_index_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -257,10 +339,18 @@ def test_index_endpoint_service_client_mtls_env_auto(client_class, transport_cla # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -281,9 +371,14 @@ def test_index_endpoint_service_client_mtls_env_auto(client_class, transport_cla ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -297,16 +392,27 @@ def test_index_endpoint_service_client_mtls_env_auto(client_class, transport_cla ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc"), - (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_index_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + IndexEndpointServiceClient, + transports.IndexEndpointServiceGrpcTransport, + "grpc", + ), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_index_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -319,16 +425,28 @@ def test_index_endpoint_service_client_client_options_scopes(client_class, trans client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (IndexEndpointServiceClient, transports.IndexEndpointServiceGrpcTransport, "grpc"), - (IndexEndpointServiceAsyncClient, transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_index_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + IndexEndpointServiceClient, + transports.IndexEndpointServiceGrpcTransport, + "grpc", + ), + ( + IndexEndpointServiceAsyncClient, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_index_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -343,10 +461,12 @@ def test_index_endpoint_service_client_client_options_credentials_file(client_cl def test_index_endpoint_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = IndexEndpointServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -359,10 +479,12 @@ def test_index_endpoint_service_client_client_options_from_dict(): ) -def test_create_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.CreateIndexEndpointRequest): +def test_create_index_endpoint( + transport: str = "grpc", + request_type=index_endpoint_service.CreateIndexEndpointRequest, +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -371,10 +493,10 @@ def test_create_index_endpoint(transport: str = 'grpc', request_type=index_endpo # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: + type(client.transport.create_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_index_endpoint(request) @@ -396,25 +518,27 @@ def test_create_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: + type(client.transport.create_index_endpoint), "__call__" + ) as call: client.create_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.CreateIndexEndpointRequest() + @pytest.mark.asyncio -async def test_create_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.CreateIndexEndpointRequest): +async def test_create_index_endpoint_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.CreateIndexEndpointRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -423,11 +547,11 @@ async def test_create_index_endpoint_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: + type(client.transport.create_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_index_endpoint(request) @@ -448,20 +572,18 @@ async def test_create_index_endpoint_async_from_dict(): def test_create_index_endpoint_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.CreateIndexEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_index_endpoint), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_index_endpoint(request) @@ -472,10 +594,7 @@ def test_create_index_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -487,13 +606,15 @@ async def test_create_index_endpoint_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.CreateIndexEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_index_endpoint), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_index_endpoint(request) @@ -504,29 +625,24 @@ async def test_create_index_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_index_endpoint_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: + type(client.transport.create_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_index_endpoint( - parent='parent_value', - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + parent="parent_value", + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -534,23 +650,23 @@ def test_create_index_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint( + name="name_value" + ) def test_create_index_endpoint_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_index_endpoint( index_endpoint_service.CreateIndexEndpointRequest(), - parent='parent_value', - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + parent="parent_value", + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), ) @@ -562,19 +678,19 @@ async def test_create_index_endpoint_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_index_endpoint), - '__call__') as call: + type(client.transport.create_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_index_endpoint( - parent='parent_value', - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + parent="parent_value", + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -582,9 +698,11 @@ async def test_create_index_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint( + name="name_value" + ) @pytest.mark.asyncio @@ -598,15 +716,16 @@ async def test_create_index_endpoint_flattened_error_async(): with pytest.raises(ValueError): await client.create_index_endpoint( index_endpoint_service.CreateIndexEndpointRequest(), - parent='parent_value', - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), + parent="parent_value", + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), ) -def test_get_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.GetIndexEndpointRequest): +def test_get_index_endpoint( + transport: str = "grpc", request_type=index_endpoint_service.GetIndexEndpointRequest +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -615,20 +734,15 @@ def test_get_index_endpoint(transport: str = 'grpc', request_type=index_endpoint # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: + type(client.transport.get_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = index_endpoint.IndexEndpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - - network='network_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + network="network_value", ) response = client.get_index_endpoint(request) @@ -643,15 +757,15 @@ def test_get_index_endpoint(transport: str = 'grpc', request_type=index_endpoint assert isinstance(response, index_endpoint.IndexEndpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.network == 'network_value' + assert response.network == "network_value" def test_get_index_endpoint_from_dict(): @@ -662,25 +776,27 @@ def test_get_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: + type(client.transport.get_index_endpoint), "__call__" + ) as call: client.get_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.GetIndexEndpointRequest() + @pytest.mark.asyncio -async def test_get_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.GetIndexEndpointRequest): +async def test_get_index_endpoint_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.GetIndexEndpointRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -689,16 +805,18 @@ async def test_get_index_endpoint_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: + type(client.transport.get_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint.IndexEndpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - network='network_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_endpoint.IndexEndpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + network="network_value", + ) + ) response = await client.get_index_endpoint(request) @@ -711,15 +829,15 @@ async def test_get_index_endpoint_async(transport: str = 'grpc_asyncio', request # Establish that the response is the type that we expect. assert isinstance(response, index_endpoint.IndexEndpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.network == 'network_value' + assert response.network == "network_value" @pytest.mark.asyncio @@ -728,19 +846,17 @@ async def test_get_index_endpoint_async_from_dict(): def test_get_index_endpoint_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.GetIndexEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: + type(client.transport.get_index_endpoint), "__call__" + ) as call: call.return_value = index_endpoint.IndexEndpoint() client.get_index_endpoint(request) @@ -752,10 +868,7 @@ def test_get_index_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -767,13 +880,15 @@ async def test_get_index_endpoint_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.GetIndexEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint.IndexEndpoint()) + type(client.transport.get_index_endpoint), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_endpoint.IndexEndpoint() + ) await client.get_index_endpoint(request) @@ -784,49 +899,39 @@ async def test_get_index_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_index_endpoint_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: + type(client.transport.get_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = index_endpoint.IndexEndpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_index_endpoint( - name='name_value', - ) + client.get_index_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_index_endpoint_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_index_endpoint( - index_endpoint_service.GetIndexEndpointRequest(), - name='name_value', + index_endpoint_service.GetIndexEndpointRequest(), name="name_value", ) @@ -838,24 +943,24 @@ async def test_get_index_endpoint_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_index_endpoint), - '__call__') as call: + type(client.transport.get_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = index_endpoint.IndexEndpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint.IndexEndpoint()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_endpoint.IndexEndpoint() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_index_endpoint( - name='name_value', - ) + response = await client.get_index_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -868,15 +973,16 @@ async def test_get_index_endpoint_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_index_endpoint( - index_endpoint_service.GetIndexEndpointRequest(), - name='name_value', + index_endpoint_service.GetIndexEndpointRequest(), name="name_value", ) -def test_list_index_endpoints(transport: str = 'grpc', request_type=index_endpoint_service.ListIndexEndpointsRequest): +def test_list_index_endpoints( + transport: str = "grpc", + request_type=index_endpoint_service.ListIndexEndpointsRequest, +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -885,12 +991,11 @@ def test_list_index_endpoints(transport: str = 'grpc', request_type=index_endpoi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = index_endpoint_service.ListIndexEndpointsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_index_endpoints(request) @@ -905,7 +1010,7 @@ def test_list_index_endpoints(transport: str = 'grpc', request_type=index_endpoi assert isinstance(response, pagers.ListIndexEndpointsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_index_endpoints_from_dict(): @@ -916,25 +1021,27 @@ def test_list_index_endpoints_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: client.list_index_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.ListIndexEndpointsRequest() + @pytest.mark.asyncio -async def test_list_index_endpoints_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.ListIndexEndpointsRequest): +async def test_list_index_endpoints_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.ListIndexEndpointsRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -943,12 +1050,14 @@ async def test_list_index_endpoints_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint_service.ListIndexEndpointsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_endpoint_service.ListIndexEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_index_endpoints(request) @@ -961,7 +1070,7 @@ async def test_list_index_endpoints_async(transport: str = 'grpc_asyncio', reque # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListIndexEndpointsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -970,19 +1079,17 @@ async def test_list_index_endpoints_async_from_dict(): def test_list_index_endpoints_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.ListIndexEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: call.return_value = index_endpoint_service.ListIndexEndpointsResponse() client.list_index_endpoints(request) @@ -994,10 +1101,7 @@ def test_list_index_endpoints_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1009,13 +1113,15 @@ async def test_list_index_endpoints_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.ListIndexEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint_service.ListIndexEndpointsResponse()) + type(client.transport.list_index_endpoints), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_endpoint_service.ListIndexEndpointsResponse() + ) await client.list_index_endpoints(request) @@ -1026,49 +1132,39 @@ async def test_list_index_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_index_endpoints_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = index_endpoint_service.ListIndexEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_index_endpoints( - parent='parent_value', - ) + client.list_index_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_index_endpoints_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_index_endpoints( - index_endpoint_service.ListIndexEndpointsRequest(), - parent='parent_value', + index_endpoint_service.ListIndexEndpointsRequest(), parent="parent_value", ) @@ -1080,24 +1176,24 @@ async def test_list_index_endpoints_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = index_endpoint_service.ListIndexEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_endpoint_service.ListIndexEndpointsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_endpoint_service.ListIndexEndpointsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_index_endpoints( - parent='parent_value', - ) + response = await client.list_index_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -1110,20 +1206,17 @@ async def test_list_index_endpoints_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_index_endpoints( - index_endpoint_service.ListIndexEndpointsRequest(), - parent='parent_value', + index_endpoint_service.ListIndexEndpointsRequest(), parent="parent_value", ) def test_list_index_endpoints_pager(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( index_endpoint_service.ListIndexEndpointsResponse( @@ -1132,17 +1225,14 @@ def test_list_index_endpoints_pager(): index_endpoint.IndexEndpoint(), index_endpoint.IndexEndpoint(), ], - next_page_token='abc', + next_page_token="abc", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[], - next_page_token='def', + index_endpoints=[], next_page_token="def", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[ - index_endpoint.IndexEndpoint(), - ], - next_page_token='ghi', + index_endpoints=[index_endpoint.IndexEndpoint(),], + next_page_token="ghi", ), index_endpoint_service.ListIndexEndpointsResponse( index_endpoints=[ @@ -1155,9 +1245,7 @@ def test_list_index_endpoints_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_index_endpoints(request={}) @@ -1165,18 +1253,16 @@ def test_list_index_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, index_endpoint.IndexEndpoint) - for i in results) + assert all(isinstance(i, index_endpoint.IndexEndpoint) for i in results) + def test_list_index_endpoints_pages(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__') as call: + type(client.transport.list_index_endpoints), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( index_endpoint_service.ListIndexEndpointsResponse( @@ -1185,17 +1271,14 @@ def test_list_index_endpoints_pages(): index_endpoint.IndexEndpoint(), index_endpoint.IndexEndpoint(), ], - next_page_token='abc', + next_page_token="abc", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[], - next_page_token='def', + index_endpoints=[], next_page_token="def", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[ - index_endpoint.IndexEndpoint(), - ], - next_page_token='ghi', + index_endpoints=[index_endpoint.IndexEndpoint(),], + next_page_token="ghi", ), index_endpoint_service.ListIndexEndpointsResponse( index_endpoints=[ @@ -1206,9 +1289,10 @@ def test_list_index_endpoints_pages(): RuntimeError, ) pages = list(client.list_index_endpoints(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_index_endpoints_async_pager(): client = IndexEndpointServiceAsyncClient( @@ -1217,8 +1301,10 @@ async def test_list_index_endpoints_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_index_endpoints), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( index_endpoint_service.ListIndexEndpointsResponse( @@ -1227,17 +1313,14 @@ async def test_list_index_endpoints_async_pager(): index_endpoint.IndexEndpoint(), index_endpoint.IndexEndpoint(), ], - next_page_token='abc', + next_page_token="abc", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[], - next_page_token='def', + index_endpoints=[], next_page_token="def", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[ - index_endpoint.IndexEndpoint(), - ], - next_page_token='ghi', + index_endpoints=[index_endpoint.IndexEndpoint(),], + next_page_token="ghi", ), index_endpoint_service.ListIndexEndpointsResponse( index_endpoints=[ @@ -1248,14 +1331,14 @@ async def test_list_index_endpoints_async_pager(): RuntimeError, ) async_pager = await client.list_index_endpoints(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, index_endpoint.IndexEndpoint) - for i in responses) + assert all(isinstance(i, index_endpoint.IndexEndpoint) for i in responses) + @pytest.mark.asyncio async def test_list_index_endpoints_async_pages(): @@ -1265,8 +1348,10 @@ async def test_list_index_endpoints_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_index_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_index_endpoints), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( index_endpoint_service.ListIndexEndpointsResponse( @@ -1275,17 +1360,14 @@ async def test_list_index_endpoints_async_pages(): index_endpoint.IndexEndpoint(), index_endpoint.IndexEndpoint(), ], - next_page_token='abc', + next_page_token="abc", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[], - next_page_token='def', + index_endpoints=[], next_page_token="def", ), index_endpoint_service.ListIndexEndpointsResponse( - index_endpoints=[ - index_endpoint.IndexEndpoint(), - ], - next_page_token='ghi', + index_endpoints=[index_endpoint.IndexEndpoint(),], + next_page_token="ghi", ), index_endpoint_service.ListIndexEndpointsResponse( index_endpoints=[ @@ -1298,14 +1380,16 @@ async def test_list_index_endpoints_async_pages(): pages = [] async for page_ in (await client.list_index_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.UpdateIndexEndpointRequest): +def test_update_index_endpoint( + transport: str = "grpc", + request_type=index_endpoint_service.UpdateIndexEndpointRequest, +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1314,20 +1398,15 @@ def test_update_index_endpoint(transport: str = 'grpc', request_type=index_endpo # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: + type(client.transport.update_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_index_endpoint.IndexEndpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - - network='network_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + network="network_value", ) response = client.update_index_endpoint(request) @@ -1342,15 +1421,15 @@ def test_update_index_endpoint(transport: str = 'grpc', request_type=index_endpo assert isinstance(response, gca_index_endpoint.IndexEndpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.network == 'network_value' + assert response.network == "network_value" def test_update_index_endpoint_from_dict(): @@ -1361,25 +1440,27 @@ def test_update_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: + type(client.transport.update_index_endpoint), "__call__" + ) as call: client.update_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() + @pytest.mark.asyncio -async def test_update_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.UpdateIndexEndpointRequest): +async def test_update_index_endpoint_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.UpdateIndexEndpointRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1388,16 +1469,18 @@ async def test_update_index_endpoint_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: + type(client.transport.update_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_index_endpoint.IndexEndpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - network='network_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_index_endpoint.IndexEndpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + network="network_value", + ) + ) response = await client.update_index_endpoint(request) @@ -1410,15 +1493,15 @@ async def test_update_index_endpoint_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, gca_index_endpoint.IndexEndpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.network == 'network_value' + assert response.network == "network_value" @pytest.mark.asyncio @@ -1427,19 +1510,17 @@ async def test_update_index_endpoint_async_from_dict(): def test_update_index_endpoint_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.UpdateIndexEndpointRequest() - request.index_endpoint.name = 'index_endpoint.name/value' + request.index_endpoint.name = "index_endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: + type(client.transport.update_index_endpoint), "__call__" + ) as call: call.return_value = gca_index_endpoint.IndexEndpoint() client.update_index_endpoint(request) @@ -1452,9 +1533,9 @@ def test_update_index_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'index_endpoint.name=index_endpoint.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "index_endpoint.name=index_endpoint.name/value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1466,13 +1547,15 @@ async def test_update_index_endpoint_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.UpdateIndexEndpointRequest() - request.index_endpoint.name = 'index_endpoint.name/value' + request.index_endpoint.name = "index_endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_index_endpoint.IndexEndpoint()) + type(client.transport.update_index_endpoint), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_index_endpoint.IndexEndpoint() + ) await client.update_index_endpoint(request) @@ -1484,28 +1567,26 @@ async def test_update_index_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'index_endpoint.name=index_endpoint.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "index_endpoint.name=index_endpoint.name/value", + ) in kw["metadata"] def test_update_index_endpoint_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: + type(client.transport.update_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_index_endpoint.IndexEndpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_index_endpoint( - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1513,23 +1594,23 @@ def test_update_index_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_index_endpoint_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_index_endpoint( index_endpoint_service.UpdateIndexEndpointRequest(), - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1541,17 +1622,19 @@ async def test_update_index_endpoint_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_index_endpoint), - '__call__') as call: + type(client.transport.update_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_index_endpoint.IndexEndpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_index_endpoint.IndexEndpoint()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_index_endpoint.IndexEndpoint() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_index_endpoint( - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1559,9 +1642,11 @@ async def test_update_index_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint(name='name_value') + assert args[0].index_endpoint == gca_index_endpoint.IndexEndpoint( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -1575,15 +1660,17 @@ async def test_update_index_endpoint_flattened_error_async(): with pytest.raises(ValueError): await client.update_index_endpoint( index_endpoint_service.UpdateIndexEndpointRequest(), - index_endpoint=gca_index_endpoint.IndexEndpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index_endpoint=gca_index_endpoint.IndexEndpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_index_endpoint(transport: str = 'grpc', request_type=index_endpoint_service.DeleteIndexEndpointRequest): +def test_delete_index_endpoint( + transport: str = "grpc", + request_type=index_endpoint_service.DeleteIndexEndpointRequest, +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1592,10 +1679,10 @@ def test_delete_index_endpoint(transport: str = 'grpc', request_type=index_endpo # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: + type(client.transport.delete_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_index_endpoint(request) @@ -1617,25 +1704,27 @@ def test_delete_index_endpoint_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: + type(client.transport.delete_index_endpoint), "__call__" + ) as call: client.delete_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.DeleteIndexEndpointRequest() + @pytest.mark.asyncio -async def test_delete_index_endpoint_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.DeleteIndexEndpointRequest): +async def test_delete_index_endpoint_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.DeleteIndexEndpointRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1644,11 +1733,11 @@ async def test_delete_index_endpoint_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: + type(client.transport.delete_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_index_endpoint(request) @@ -1669,20 +1758,18 @@ async def test_delete_index_endpoint_async_from_dict(): def test_delete_index_endpoint_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.DeleteIndexEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_index_endpoint), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_index_endpoint(request) @@ -1693,10 +1780,7 @@ def test_delete_index_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1708,13 +1792,15 @@ async def test_delete_index_endpoint_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.DeleteIndexEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_index_endpoint), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_index_endpoint(request) @@ -1725,49 +1811,39 @@ async def test_delete_index_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_index_endpoint_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: + type(client.transport.delete_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_index_endpoint( - name='name_value', - ) + client.delete_index_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_index_endpoint_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_index_endpoint( - index_endpoint_service.DeleteIndexEndpointRequest(), - name='name_value', + index_endpoint_service.DeleteIndexEndpointRequest(), name="name_value", ) @@ -1779,26 +1855,24 @@ async def test_delete_index_endpoint_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_index_endpoint), - '__call__') as call: + type(client.transport.delete_index_endpoint), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_index_endpoint( - name='name_value', - ) + response = await client.delete_index_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -1811,15 +1885,15 @@ async def test_delete_index_endpoint_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_index_endpoint( - index_endpoint_service.DeleteIndexEndpointRequest(), - name='name_value', + index_endpoint_service.DeleteIndexEndpointRequest(), name="name_value", ) -def test_deploy_index(transport: str = 'grpc', request_type=index_endpoint_service.DeployIndexRequest): +def test_deploy_index( + transport: str = "grpc", request_type=index_endpoint_service.DeployIndexRequest +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1827,11 +1901,9 @@ def test_deploy_index(transport: str = 'grpc', request_type=index_endpoint_servi request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.deploy_index(request) @@ -1853,25 +1925,25 @@ def test_deploy_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: client.deploy_index() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.DeployIndexRequest() + @pytest.mark.asyncio -async def test_deploy_index_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.DeployIndexRequest): +async def test_deploy_index_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.DeployIndexRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1879,12 +1951,10 @@ async def test_deploy_index_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.deploy_index(request) @@ -1905,20 +1975,16 @@ async def test_deploy_index_async_from_dict(): def test_deploy_index_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.DeployIndexRequest() - request.index_endpoint = 'index_endpoint/value' + request.index_endpoint = "index_endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.deploy_index(request) @@ -1929,10 +1995,9 @@ def test_deploy_index_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'index_endpoint=index_endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "index_endpoint=index_endpoint/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio @@ -1944,13 +2009,13 @@ async def test_deploy_index_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.DeployIndexRequest() - request.index_endpoint = 'index_endpoint/value' + request.index_endpoint = "index_endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.deploy_index(request) @@ -1961,29 +2026,24 @@ async def test_deploy_index_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'index_endpoint=index_endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "index_endpoint=index_endpoint/value",) in kw[ + "metadata" + ] def test_deploy_index_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_index( - index_endpoint='index_endpoint_value', - deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + index_endpoint="index_endpoint_value", + deployed_index=gca_index_endpoint.DeployedIndex(id="id_value"), ) # Establish that the underlying call was made with the expected @@ -1991,23 +2051,21 @@ def test_deploy_index_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].index_endpoint == 'index_endpoint_value' + assert args[0].index_endpoint == "index_endpoint_value" - assert args[0].deployed_index == gca_index_endpoint.DeployedIndex(id='id_value') + assert args[0].deployed_index == gca_index_endpoint.DeployedIndex(id="id_value") def test_deploy_index_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.deploy_index( index_endpoint_service.DeployIndexRequest(), - index_endpoint='index_endpoint_value', - deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + index_endpoint="index_endpoint_value", + deployed_index=gca_index_endpoint.DeployedIndex(id="id_value"), ) @@ -2018,20 +2076,18 @@ async def test_deploy_index_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_index( - index_endpoint='index_endpoint_value', - deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + index_endpoint="index_endpoint_value", + deployed_index=gca_index_endpoint.DeployedIndex(id="id_value"), ) # Establish that the underlying call was made with the expected @@ -2039,9 +2095,9 @@ async def test_deploy_index_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].index_endpoint == 'index_endpoint_value' + assert args[0].index_endpoint == "index_endpoint_value" - assert args[0].deployed_index == gca_index_endpoint.DeployedIndex(id='id_value') + assert args[0].deployed_index == gca_index_endpoint.DeployedIndex(id="id_value") @pytest.mark.asyncio @@ -2055,15 +2111,16 @@ async def test_deploy_index_flattened_error_async(): with pytest.raises(ValueError): await client.deploy_index( index_endpoint_service.DeployIndexRequest(), - index_endpoint='index_endpoint_value', - deployed_index=gca_index_endpoint.DeployedIndex(id='id_value'), + index_endpoint="index_endpoint_value", + deployed_index=gca_index_endpoint.DeployedIndex(id="id_value"), ) -def test_undeploy_index(transport: str = 'grpc', request_type=index_endpoint_service.UndeployIndexRequest): +def test_undeploy_index( + transport: str = "grpc", request_type=index_endpoint_service.UndeployIndexRequest +): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2071,11 +2128,9 @@ def test_undeploy_index(transport: str = 'grpc', request_type=index_endpoint_ser request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.undeploy_index(request) @@ -2097,25 +2152,25 @@ def test_undeploy_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: client.undeploy_index() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.UndeployIndexRequest() + @pytest.mark.asyncio -async def test_undeploy_index_async(transport: str = 'grpc_asyncio', request_type=index_endpoint_service.UndeployIndexRequest): +async def test_undeploy_index_async( + transport: str = "grpc_asyncio", + request_type=index_endpoint_service.UndeployIndexRequest, +): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2123,12 +2178,10 @@ async def test_undeploy_index_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.undeploy_index(request) @@ -2149,20 +2202,16 @@ async def test_undeploy_index_async_from_dict(): def test_undeploy_index_field_headers(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.UndeployIndexRequest() - request.index_endpoint = 'index_endpoint/value' + request.index_endpoint = "index_endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.undeploy_index(request) @@ -2173,10 +2222,9 @@ def test_undeploy_index_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'index_endpoint=index_endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "index_endpoint=index_endpoint/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio @@ -2188,13 +2236,13 @@ async def test_undeploy_index_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_endpoint_service.UndeployIndexRequest() - request.index_endpoint = 'index_endpoint/value' + request.index_endpoint = "index_endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.undeploy_index(request) @@ -2205,29 +2253,24 @@ async def test_undeploy_index_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'index_endpoint=index_endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "index_endpoint=index_endpoint/value",) in kw[ + "metadata" + ] def test_undeploy_index_flattened(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.undeploy_index( - index_endpoint='index_endpoint_value', - deployed_index_id='deployed_index_id_value', + index_endpoint="index_endpoint_value", + deployed_index_id="deployed_index_id_value", ) # Establish that the underlying call was made with the expected @@ -2235,23 +2278,21 @@ def test_undeploy_index_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].index_endpoint == 'index_endpoint_value' + assert args[0].index_endpoint == "index_endpoint_value" - assert args[0].deployed_index_id == 'deployed_index_id_value' + assert args[0].deployed_index_id == "deployed_index_id_value" def test_undeploy_index_flattened_error(): - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_index( index_endpoint_service.UndeployIndexRequest(), - index_endpoint='index_endpoint_value', - deployed_index_id='deployed_index_id_value', + index_endpoint="index_endpoint_value", + deployed_index_id="deployed_index_id_value", ) @@ -2262,20 +2303,18 @@ async def test_undeploy_index_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_index), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_index( - index_endpoint='index_endpoint_value', - deployed_index_id='deployed_index_id_value', + index_endpoint="index_endpoint_value", + deployed_index_id="deployed_index_id_value", ) # Establish that the underlying call was made with the expected @@ -2283,9 +2322,9 @@ async def test_undeploy_index_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].index_endpoint == 'index_endpoint_value' + assert args[0].index_endpoint == "index_endpoint_value" - assert args[0].deployed_index_id == 'deployed_index_id_value' + assert args[0].deployed_index_id == "deployed_index_id_value" @pytest.mark.asyncio @@ -2299,8 +2338,8 @@ async def test_undeploy_index_flattened_error_async(): with pytest.raises(ValueError): await client.undeploy_index( index_endpoint_service.UndeployIndexRequest(), - index_endpoint='index_endpoint_value', - deployed_index_id='deployed_index_id_value', + index_endpoint="index_endpoint_value", + deployed_index_id="deployed_index_id_value", ) @@ -2311,8 +2350,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2331,8 +2369,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = IndexEndpointServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -2360,13 +2397,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.IndexEndpointServiceGrpcTransport, - transports.IndexEndpointServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.IndexEndpointServiceGrpcTransport, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2374,13 +2414,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.IndexEndpointServiceGrpcTransport, - ) + client = IndexEndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.IndexEndpointServiceGrpcTransport,) def test_index_endpoint_service_base_transport_error(): @@ -2388,13 +2423,15 @@ def test_index_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.IndexEndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_index_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.IndexEndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2403,14 +2440,14 @@ def test_index_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_index_endpoint', - 'get_index_endpoint', - 'list_index_endpoints', - 'update_index_endpoint', - 'delete_index_endpoint', - 'deploy_index', - 'undeploy_index', - ) + "create_index_endpoint", + "get_index_endpoint", + "list_index_endpoints", + "update_index_endpoint", + "delete_index_endpoint", + "deploy_index", + "undeploy_index", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2423,23 +2460,28 @@ def test_index_endpoint_service_base_transport(): def test_index_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.IndexEndpointServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_index_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.IndexEndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.IndexEndpointServiceTransport() @@ -2448,11 +2490,11 @@ def test_index_endpoint_service_base_transport_with_adc(): def test_index_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) IndexEndpointServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -2460,18 +2502,26 @@ def test_index_endpoint_service_auth_adc(): def test_index_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.IndexEndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.IndexEndpointServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.IndexEndpointServiceGrpcTransport, transports.IndexEndpointServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.IndexEndpointServiceGrpcTransport, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + ], +) def test_index_endpoint_service_grpc_transport_client_cert_source_for_mtls( - transport_class + transport_class, ): cred = credentials.AnonymousCredentials() @@ -2481,15 +2531,13 @@ def test_index_endpoint_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2504,38 +2552,40 @@ def test_index_endpoint_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_index_endpoint_service_host_no_port(): client = IndexEndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_index_endpoint_service_host_with_port(): client = IndexEndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_index_endpoint_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.IndexEndpointServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2543,12 +2593,11 @@ def test_index_endpoint_service_grpc_transport_channel(): def test_index_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.IndexEndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2557,12 +2606,22 @@ def test_index_endpoint_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.IndexEndpointServiceGrpcTransport, transports.IndexEndpointServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.IndexEndpointServiceGrpcTransport, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + ], +) def test_index_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2571,7 +2630,7 @@ def test_index_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2587,9 +2646,7 @@ def test_index_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2603,17 +2660,23 @@ def test_index_endpoint_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.IndexEndpointServiceGrpcTransport, transports.IndexEndpointServiceGrpcAsyncIOTransport]) -def test_index_endpoint_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.IndexEndpointServiceGrpcTransport, + transports.IndexEndpointServiceGrpcAsyncIOTransport, + ], +) +def test_index_endpoint_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2630,9 +2693,7 @@ def test_index_endpoint_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2645,16 +2706,12 @@ def test_index_endpoint_service_transport_channel_mtls_with_adc( def test_index_endpoint_service_grpc_lro_client(): client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2662,16 +2719,12 @@ def test_index_endpoint_service_grpc_lro_client(): def test_index_endpoint_service_grpc_lro_async_client(): client = IndexEndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2682,17 +2735,18 @@ def test_index_path(): location = "clam" index = "whelk" - expected = "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + expected = "projects/{project}/locations/{location}/indexes/{index}".format( + project=project, location=location, index=index, + ) actual = IndexEndpointServiceClient.index_path(project, location, index) assert expected == actual def test_parse_index_path(): expected = { - "project": "octopus", - "location": "oyster", - "index": "nudibranch", - + "project": "octopus", + "location": "oyster", + "index": "nudibranch", } path = IndexEndpointServiceClient.index_path(**expected) @@ -2700,22 +2754,26 @@ def test_parse_index_path(): actual = IndexEndpointServiceClient.parse_index_path(path) assert expected == actual + def test_index_endpoint_path(): project = "cuttlefish" location = "mussel" index_endpoint = "winkle" - expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) - actual = IndexEndpointServiceClient.index_endpoint_path(project, location, index_endpoint) + expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( + project=project, location=location, index_endpoint=index_endpoint, + ) + actual = IndexEndpointServiceClient.index_endpoint_path( + project, location, index_endpoint + ) assert expected == actual def test_parse_index_endpoint_path(): expected = { - "project": "nautilus", - "location": "scallop", - "index_endpoint": "abalone", - + "project": "nautilus", + "location": "scallop", + "index_endpoint": "abalone", } path = IndexEndpointServiceClient.index_endpoint_path(**expected) @@ -2723,22 +2781,26 @@ def test_parse_index_endpoint_path(): actual = IndexEndpointServiceClient.parse_index_endpoint_path(path) assert expected == actual + def test_index_endpoint_path(): project = "squid" location = "clam" index_endpoint = "whelk" - expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) - actual = IndexEndpointServiceClient.index_endpoint_path(project, location, index_endpoint) + expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( + project=project, location=location, index_endpoint=index_endpoint, + ) + actual = IndexEndpointServiceClient.index_endpoint_path( + project, location, index_endpoint + ) assert expected == actual def test_parse_index_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "index_endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "index_endpoint": "nudibranch", } path = IndexEndpointServiceClient.index_endpoint_path(**expected) @@ -2746,18 +2808,20 @@ def test_parse_index_endpoint_path(): actual = IndexEndpointServiceClient.parse_index_endpoint_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = IndexEndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = IndexEndpointServiceClient.common_billing_account_path(**expected) @@ -2765,18 +2829,18 @@ def test_parse_common_billing_account_path(): actual = IndexEndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = IndexEndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = IndexEndpointServiceClient.common_folder_path(**expected) @@ -2784,18 +2848,18 @@ def test_parse_common_folder_path(): actual = IndexEndpointServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = IndexEndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = IndexEndpointServiceClient.common_organization_path(**expected) @@ -2803,18 +2867,18 @@ def test_parse_common_organization_path(): actual = IndexEndpointServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = IndexEndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = IndexEndpointServiceClient.common_project_path(**expected) @@ -2822,20 +2886,22 @@ def test_parse_common_project_path(): actual = IndexEndpointServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = IndexEndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = IndexEndpointServiceClient.common_location_path(**expected) @@ -2847,17 +2913,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.IndexEndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.IndexEndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: client = IndexEndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.IndexEndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.IndexEndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = IndexEndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py index 416b2087cc..5d9586883e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.index_service import IndexServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.index_service import ( + IndexServiceAsyncClient, +) from google.cloud.aiplatform_v1beta1.services.index_service import IndexServiceClient from google.cloud.aiplatform_v1beta1.services.index_service import pagers from google.cloud.aiplatform_v1beta1.services.index_service import transports @@ -59,7 +61,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -70,36 +76,45 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert IndexServiceClient._get_default_mtls_endpoint(None) is None - assert IndexServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert IndexServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert IndexServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert IndexServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + IndexServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + IndexServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + IndexServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + IndexServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert IndexServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [ - IndexServiceClient, - IndexServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [IndexServiceClient, IndexServiceAsyncClient,]) def test_index_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - IndexServiceClient, - IndexServiceAsyncClient, -]) +@pytest.mark.parametrize("client_class", [IndexServiceClient, IndexServiceAsyncClient,]) def test_index_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -109,7 +124,7 @@ def test_index_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_index_service_client_get_transport_class(): @@ -123,29 +138,42 @@ def test_index_service_client_get_transport_class(): assert transport == transports.IndexServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), - (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient)) -@mock.patch.object(IndexServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceAsyncClient)) -def test_index_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + ( + IndexServiceAsyncClient, + transports.IndexServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient) +) +@mock.patch.object( + IndexServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexServiceAsyncClient), +) +def test_index_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(IndexServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(IndexServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(IndexServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(IndexServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -161,7 +189,7 @@ def test_index_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -177,7 +205,7 @@ def test_index_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -197,13 +225,15 @@ def test_index_service_client_client_options(client_class, transport_class, tran client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -216,26 +246,50 @@ def test_index_service_client_client_options(client_class, transport_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc", "true"), - (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc", "false"), - (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient)) -@mock.patch.object(IndexServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc", "true"), + ( + IndexServiceAsyncClient, + transports.IndexServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc", "false"), + ( + IndexServiceAsyncClient, + transports.IndexServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + IndexServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(IndexServiceClient) +) +@mock.patch.object( + IndexServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(IndexServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_index_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_index_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -258,10 +312,18 @@ def test_index_service_client_mtls_env_auto(client_class, transport_class, trans # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -282,9 +344,14 @@ def test_index_service_client_mtls_env_auto(client_class, transport_class, trans ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -298,16 +365,23 @@ def test_index_service_client_mtls_env_auto(client_class, transport_class, trans ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), - (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_index_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + ( + IndexServiceAsyncClient, + transports.IndexServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_index_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -320,16 +394,24 @@ def test_index_service_client_client_options_scopes(client_class, transport_clas client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), - (IndexServiceAsyncClient, transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_index_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (IndexServiceClient, transports.IndexServiceGrpcTransport, "grpc"), + ( + IndexServiceAsyncClient, + transports.IndexServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_index_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -344,11 +426,11 @@ def test_index_service_client_client_options_credentials_file(client_class, tran def test_index_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = IndexServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = IndexServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -360,10 +442,11 @@ def test_index_service_client_client_options_from_dict(): ) -def test_create_index(transport: str = 'grpc', request_type=index_service.CreateIndexRequest): +def test_create_index( + transport: str = "grpc", request_type=index_service.CreateIndexRequest +): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -371,11 +454,9 @@ def test_create_index(transport: str = 'grpc', request_type=index_service.Create request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: + with mock.patch.object(type(client.transport.create_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_index(request) @@ -397,25 +478,24 @@ def test_create_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: + with mock.patch.object(type(client.transport.create_index), "__call__") as call: client.create_index() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.CreateIndexRequest() + @pytest.mark.asyncio -async def test_create_index_async(transport: str = 'grpc_asyncio', request_type=index_service.CreateIndexRequest): +async def test_create_index_async( + transport: str = "grpc_asyncio", request_type=index_service.CreateIndexRequest +): client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -423,12 +503,10 @@ async def test_create_index_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: + with mock.patch.object(type(client.transport.create_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_index(request) @@ -449,20 +527,16 @@ async def test_create_index_async_from_dict(): def test_create_index_field_headers(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.CreateIndexRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_index(request) @@ -473,28 +547,23 @@ def test_create_index_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_index_field_headers_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.CreateIndexRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_index(request) @@ -505,29 +574,21 @@ async def test_create_index_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_index_flattened(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: + with mock.patch.object(type(client.transport.create_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_index( - parent='parent_value', - index=gca_index.Index(name='name_value'), + parent="parent_value", index=gca_index.Index(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -535,47 +596,40 @@ def test_create_index_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].index == gca_index.Index(name='name_value') + assert args[0].index == gca_index.Index(name="name_value") def test_create_index_flattened_error(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_index( index_service.CreateIndexRequest(), - parent='parent_value', - index=gca_index.Index(name='name_value'), + parent="parent_value", + index=gca_index.Index(name="name_value"), ) @pytest.mark.asyncio async def test_create_index_flattened_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_index), - '__call__') as call: + with mock.patch.object(type(client.transport.create_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_index( - parent='parent_value', - index=gca_index.Index(name='name_value'), + parent="parent_value", index=gca_index.Index(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -583,31 +637,28 @@ async def test_create_index_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].index == gca_index.Index(name='name_value') + assert args[0].index == gca_index.Index(name="name_value") @pytest.mark.asyncio async def test_create_index_flattened_error_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_index( index_service.CreateIndexRequest(), - parent='parent_value', - index=gca_index.Index(name='name_value'), + parent="parent_value", + index=gca_index.Index(name="name_value"), ) -def test_get_index(transport: str = 'grpc', request_type=index_service.GetIndexRequest): +def test_get_index(transport: str = "grpc", request_type=index_service.GetIndexRequest): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -615,21 +666,14 @@ def test_get_index(transport: str = 'grpc', request_type=index_service.GetIndexR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = index.Index( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.get_index(request) @@ -644,15 +688,15 @@ def test_get_index(transport: str = 'grpc', request_type=index_service.GetIndexR assert isinstance(response, index.Index) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_index_from_dict(): @@ -663,25 +707,24 @@ def test_get_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: client.get_index() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.GetIndexRequest() + @pytest.mark.asyncio -async def test_get_index_async(transport: str = 'grpc_asyncio', request_type=index_service.GetIndexRequest): +async def test_get_index_async( + transport: str = "grpc_asyncio", request_type=index_service.GetIndexRequest +): client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -689,17 +732,17 @@ async def test_get_index_async(transport: str = 'grpc_asyncio', request_type=ind request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index.Index( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index.Index( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.get_index(request) @@ -712,15 +755,15 @@ async def test_get_index_async(transport: str = 'grpc_asyncio', request_type=ind # Establish that the response is the type that we expect. assert isinstance(response, index.Index) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -729,19 +772,15 @@ async def test_get_index_async_from_dict(): def test_get_index_field_headers(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.GetIndexRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: call.return_value = index.Index() client.get_index(request) @@ -753,27 +792,20 @@ def test_get_index_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_index_field_headers_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.GetIndexRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index.Index()) await client.get_index(request) @@ -785,99 +817,79 @@ async def test_get_index_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_index_flattened(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = index.Index() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_index( - name='name_value', - ) + client.get_index(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_index_flattened_error(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_index( - index_service.GetIndexRequest(), - name='name_value', + index_service.GetIndexRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_index_flattened_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_index), - '__call__') as call: + with mock.patch.object(type(client.transport.get_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = index.Index() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index.Index()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_index( - name='name_value', - ) + response = await client.get_index(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_index_flattened_error_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_index( - index_service.GetIndexRequest(), - name='name_value', + index_service.GetIndexRequest(), name="name_value", ) -def test_list_indexes(transport: str = 'grpc', request_type=index_service.ListIndexesRequest): +def test_list_indexes( + transport: str = "grpc", request_type=index_service.ListIndexesRequest +): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -885,13 +897,10 @@ def test_list_indexes(transport: str = 'grpc', request_type=index_service.ListIn request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = index_service.ListIndexesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_indexes(request) @@ -906,7 +915,7 @@ def test_list_indexes(transport: str = 'grpc', request_type=index_service.ListIn assert isinstance(response, pagers.ListIndexesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_indexes_from_dict(): @@ -917,25 +926,24 @@ def test_list_indexes_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: client.list_indexes() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.ListIndexesRequest() + @pytest.mark.asyncio -async def test_list_indexes_async(transport: str = 'grpc_asyncio', request_type=index_service.ListIndexesRequest): +async def test_list_indexes_async( + transport: str = "grpc_asyncio", request_type=index_service.ListIndexesRequest +): client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -943,13 +951,11 @@ async def test_list_indexes_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_service.ListIndexesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_service.ListIndexesResponse(next_page_token="next_page_token_value",) + ) response = await client.list_indexes(request) @@ -962,7 +968,7 @@ async def test_list_indexes_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListIndexesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -971,19 +977,15 @@ async def test_list_indexes_async_from_dict(): def test_list_indexes_field_headers(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.ListIndexesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: call.return_value = index_service.ListIndexesResponse() client.list_indexes(request) @@ -995,28 +997,23 @@ def test_list_indexes_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_indexes_field_headers_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.ListIndexesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_service.ListIndexesResponse()) + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_service.ListIndexesResponse() + ) await client.list_indexes(request) @@ -1027,138 +1024,98 @@ async def test_list_indexes_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_indexes_flattened(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = index_service.ListIndexesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_indexes( - parent='parent_value', - ) + client.list_indexes(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_indexes_flattened_error(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_indexes( - index_service.ListIndexesRequest(), - parent='parent_value', + index_service.ListIndexesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_indexes_flattened_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = index_service.ListIndexesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(index_service.ListIndexesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + index_service.ListIndexesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_indexes( - parent='parent_value', - ) + response = await client.list_indexes(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_indexes_flattened_error_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_indexes( - index_service.ListIndexesRequest(), - parent='parent_value', + index_service.ListIndexesRequest(), parent="parent_value", ) def test_list_indexes_pager(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - index.Index(), - ], - next_page_token='abc', - ), - index_service.ListIndexesResponse( - indexes=[], - next_page_token='def', - ), - index_service.ListIndexesResponse( - indexes=[ - index.Index(), - ], - next_page_token='ghi', + indexes=[index.Index(), index.Index(), index.Index(),], + next_page_token="abc", ), + index_service.ListIndexesResponse(indexes=[], next_page_token="def",), index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - ], + indexes=[index.Index(),], next_page_token="ghi", ), + index_service.ListIndexesResponse(indexes=[index.Index(), index.Index(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_indexes(request={}) @@ -1166,147 +1123,96 @@ def test_list_indexes_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, index.Index) - for i in results) + assert all(isinstance(i, index.Index) for i in results) + def test_list_indexes_pages(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_indexes), - '__call__') as call: + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - index.Index(), - ], - next_page_token='abc', - ), - index_service.ListIndexesResponse( - indexes=[], - next_page_token='def', - ), - index_service.ListIndexesResponse( - indexes=[ - index.Index(), - ], - next_page_token='ghi', + indexes=[index.Index(), index.Index(), index.Index(),], + next_page_token="abc", ), + index_service.ListIndexesResponse(indexes=[], next_page_token="def",), index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - ], + indexes=[index.Index(),], next_page_token="ghi", ), + index_service.ListIndexesResponse(indexes=[index.Index(), index.Index(),],), RuntimeError, ) pages = list(client.list_indexes(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_indexes_async_pager(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_indexes), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_indexes), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - index.Index(), - ], - next_page_token='abc', - ), - index_service.ListIndexesResponse( - indexes=[], - next_page_token='def', - ), - index_service.ListIndexesResponse( - indexes=[ - index.Index(), - ], - next_page_token='ghi', + indexes=[index.Index(), index.Index(), index.Index(),], + next_page_token="abc", ), + index_service.ListIndexesResponse(indexes=[], next_page_token="def",), index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - ], + indexes=[index.Index(),], next_page_token="ghi", ), + index_service.ListIndexesResponse(indexes=[index.Index(), index.Index(),],), RuntimeError, ) async_pager = await client.list_indexes(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, index.Index) - for i in responses) + assert all(isinstance(i, index.Index) for i in responses) + @pytest.mark.asyncio async def test_list_indexes_async_pages(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_indexes), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_indexes), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - index.Index(), - ], - next_page_token='abc', - ), - index_service.ListIndexesResponse( - indexes=[], - next_page_token='def', - ), - index_service.ListIndexesResponse( - indexes=[ - index.Index(), - ], - next_page_token='ghi', + indexes=[index.Index(), index.Index(), index.Index(),], + next_page_token="abc", ), + index_service.ListIndexesResponse(indexes=[], next_page_token="def",), index_service.ListIndexesResponse( - indexes=[ - index.Index(), - index.Index(), - ], + indexes=[index.Index(),], next_page_token="ghi", ), + index_service.ListIndexesResponse(indexes=[index.Index(), index.Index(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_indexes(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_index(transport: str = 'grpc', request_type=index_service.UpdateIndexRequest): +def test_update_index( + transport: str = "grpc", request_type=index_service.UpdateIndexRequest +): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1314,11 +1220,9 @@ def test_update_index(transport: str = 'grpc', request_type=index_service.Update request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: + with mock.patch.object(type(client.transport.update_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_index(request) @@ -1340,25 +1244,24 @@ def test_update_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: + with mock.patch.object(type(client.transport.update_index), "__call__") as call: client.update_index() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.UpdateIndexRequest() + @pytest.mark.asyncio -async def test_update_index_async(transport: str = 'grpc_asyncio', request_type=index_service.UpdateIndexRequest): +async def test_update_index_async( + transport: str = "grpc_asyncio", request_type=index_service.UpdateIndexRequest +): client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1366,12 +1269,10 @@ async def test_update_index_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: + with mock.patch.object(type(client.transport.update_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_index(request) @@ -1392,20 +1293,16 @@ async def test_update_index_async_from_dict(): def test_update_index_field_headers(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.UpdateIndexRequest() - request.index.name = 'index.name/value' + request.index.name = "index.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.update_index), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_index(request) @@ -1416,28 +1313,23 @@ def test_update_index_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'index.name=index.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "index.name=index.name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_update_index_field_headers_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.UpdateIndexRequest() - request.index.name = 'index.name/value' + request.index.name = "index.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.update_index), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_index(request) @@ -1448,29 +1340,22 @@ async def test_update_index_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'index.name=index.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "index.name=index.name/value",) in kw["metadata"] def test_update_index_flattened(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: + with mock.patch.object(type(client.transport.update_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_index( - index=gca_index.Index(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index=gca_index.Index(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1478,47 +1363,41 @@ def test_update_index_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].index == gca_index.Index(name='name_value') + assert args[0].index == gca_index.Index(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_index_flattened_error(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_index( index_service.UpdateIndexRequest(), - index=gca_index.Index(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index=gca_index.Index(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_index_flattened_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_index), - '__call__') as call: + with mock.patch.object(type(client.transport.update_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_index( - index=gca_index.Index(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index=gca_index.Index(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1526,31 +1405,30 @@ async def test_update_index_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].index == gca_index.Index(name='name_value') + assert args[0].index == gca_index.Index(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_index_flattened_error_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_index( index_service.UpdateIndexRequest(), - index=gca_index.Index(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + index=gca_index.Index(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_index(transport: str = 'grpc', request_type=index_service.DeleteIndexRequest): +def test_delete_index( + transport: str = "grpc", request_type=index_service.DeleteIndexRequest +): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1558,11 +1436,9 @@ def test_delete_index(transport: str = 'grpc', request_type=index_service.Delete request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_index(request) @@ -1584,25 +1460,24 @@ def test_delete_index_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: client.delete_index() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.DeleteIndexRequest() + @pytest.mark.asyncio -async def test_delete_index_async(transport: str = 'grpc_asyncio', request_type=index_service.DeleteIndexRequest): +async def test_delete_index_async( + transport: str = "grpc_asyncio", request_type=index_service.DeleteIndexRequest +): client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1610,12 +1485,10 @@ async def test_delete_index_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_index(request) @@ -1636,20 +1509,16 @@ async def test_delete_index_async_from_dict(): def test_delete_index_field_headers(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.DeleteIndexRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_index(request) @@ -1660,28 +1529,23 @@ def test_delete_index_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_index_field_headers_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = index_service.DeleteIndexRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_index(request) @@ -1692,94 +1556,73 @@ async def test_delete_index_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_index_flattened(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_index( - name='name_value', - ) + client.delete_index(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_index_flattened_error(): - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_index( - index_service.DeleteIndexRequest(), - name='name_value', + index_service.DeleteIndexRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_index_flattened_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_index), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_index( - name='name_value', - ) + response = await client.delete_index(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_index_flattened_error_async(): - client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = IndexServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_index( - index_service.DeleteIndexRequest(), - name='name_value', + index_service.DeleteIndexRequest(), name="name_value", ) @@ -1790,8 +1633,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1810,8 +1652,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = IndexServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1839,13 +1680,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.IndexServiceGrpcTransport, - transports.IndexServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.IndexServiceGrpcTransport, + transports.IndexServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1853,13 +1697,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.IndexServiceGrpcTransport, - ) + client = IndexServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.IndexServiceGrpcTransport,) def test_index_service_base_transport_error(): @@ -1867,13 +1706,15 @@ def test_index_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.IndexServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_index_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.IndexServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1882,12 +1723,12 @@ def test_index_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_index', - 'get_index', - 'list_indexes', - 'update_index', - 'delete_index', - ) + "create_index", + "get_index", + "list_indexes", + "update_index", + "delete_index", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1900,23 +1741,28 @@ def test_index_service_base_transport(): def test_index_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.IndexServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_index_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.index_service.transports.IndexServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.IndexServiceTransport() @@ -1925,11 +1771,11 @@ def test_index_service_base_transport_with_adc(): def test_index_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) IndexServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1937,19 +1783,22 @@ def test_index_service_auth_adc(): def test_index_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.IndexServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.IndexServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport]) -def test_index_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport], +) +def test_index_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1958,15 +1807,13 @@ def test_index_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1981,38 +1828,40 @@ def test_index_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_index_service_host_no_port(): client = IndexServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_index_service_host_with_port(): client = IndexServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_index_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.IndexServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2020,12 +1869,11 @@ def test_index_service_grpc_transport_channel(): def test_index_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.IndexServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2034,12 +1882,17 @@ def test_index_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport]) -def test_index_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport], +) +def test_index_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2048,7 +1901,7 @@ def test_index_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2064,9 +1917,7 @@ def test_index_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2080,17 +1931,20 @@ def test_index_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport]) -def test_index_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport], +) +def test_index_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2107,9 +1961,7 @@ def test_index_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2122,16 +1974,12 @@ def test_index_service_transport_channel_mtls_with_adc( def test_index_service_grpc_lro_client(): client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2139,16 +1987,12 @@ def test_index_service_grpc_lro_client(): def test_index_service_grpc_lro_async_client(): client = IndexServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2159,17 +2003,18 @@ def test_index_path(): location = "clam" index = "whelk" - expected = "projects/{project}/locations/{location}/indexes/{index}".format(project=project, location=location, index=index, ) + expected = "projects/{project}/locations/{location}/indexes/{index}".format( + project=project, location=location, index=index, + ) actual = IndexServiceClient.index_path(project, location, index) assert expected == actual def test_parse_index_path(): expected = { - "project": "octopus", - "location": "oyster", - "index": "nudibranch", - + "project": "octopus", + "location": "oyster", + "index": "nudibranch", } path = IndexServiceClient.index_path(**expected) @@ -2177,22 +2022,24 @@ def test_parse_index_path(): actual = IndexServiceClient.parse_index_path(path) assert expected == actual + def test_index_endpoint_path(): project = "cuttlefish" location = "mussel" index_endpoint = "winkle" - expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format(project=project, location=location, index_endpoint=index_endpoint, ) + expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( + project=project, location=location, index_endpoint=index_endpoint, + ) actual = IndexServiceClient.index_endpoint_path(project, location, index_endpoint) assert expected == actual def test_parse_index_endpoint_path(): expected = { - "project": "nautilus", - "location": "scallop", - "index_endpoint": "abalone", - + "project": "nautilus", + "location": "scallop", + "index_endpoint": "abalone", } path = IndexServiceClient.index_endpoint_path(**expected) @@ -2200,18 +2047,20 @@ def test_parse_index_endpoint_path(): actual = IndexServiceClient.parse_index_endpoint_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = IndexServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = IndexServiceClient.common_billing_account_path(**expected) @@ -2219,18 +2068,18 @@ def test_parse_common_billing_account_path(): actual = IndexServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = IndexServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = IndexServiceClient.common_folder_path(**expected) @@ -2238,18 +2087,18 @@ def test_parse_common_folder_path(): actual = IndexServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = IndexServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = IndexServiceClient.common_organization_path(**expected) @@ -2257,18 +2106,18 @@ def test_parse_common_organization_path(): actual = IndexServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = IndexServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = IndexServiceClient.common_project_path(**expected) @@ -2276,20 +2125,22 @@ def test_parse_common_project_path(): actual = IndexServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = IndexServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = IndexServiceClient.common_location_path(**expected) @@ -2301,17 +2152,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.IndexServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.IndexServiceTransport, "_prep_wrapped_messages" + ) as prep: client = IndexServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.IndexServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.IndexServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = IndexServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index 35b2de66b5..b9c944280d 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.metadata_service import MetadataServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.metadata_service import MetadataServiceClient +from google.cloud.aiplatform_v1beta1.services.metadata_service import ( + MetadataServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.metadata_service import ( + MetadataServiceClient, +) from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers from google.cloud.aiplatform_v1beta1.services.metadata_service import transports from google.cloud.aiplatform_v1beta1.types import artifact @@ -69,7 +73,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -80,36 +88,52 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MetadataServiceClient._get_default_mtls_endpoint(None) is None - assert MetadataServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MetadataServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MetadataServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MetadataServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - MetadataServiceClient, - MetadataServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MetadataServiceClient, MetadataServiceAsyncClient,] +) def test_metadata_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - MetadataServiceClient, - MetadataServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MetadataServiceClient, MetadataServiceAsyncClient,] +) def test_metadata_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -119,7 +143,7 @@ def test_metadata_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_metadata_service_client_get_transport_class(): @@ -133,29 +157,44 @@ def test_metadata_service_client_get_transport_class(): assert transport == transports.MetadataServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(MetadataServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceClient)) -@mock.patch.object(MetadataServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceAsyncClient)) -def test_metadata_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MetadataServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceClient), +) +@mock.patch.object( + MetadataServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceAsyncClient), +) +def test_metadata_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MetadataServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MetadataServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MetadataServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MetadataServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -171,7 +210,7 @@ def test_metadata_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -187,7 +226,7 @@ def test_metadata_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -207,13 +246,15 @@ def test_metadata_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -226,26 +267,62 @@ def test_metadata_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc", "true"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc", "false"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(MetadataServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceClient)) -@mock.patch.object(MetadataServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MetadataServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MetadataServiceClient, + transports.MetadataServiceGrpcTransport, + "grpc", + "true", + ), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MetadataServiceClient, + transports.MetadataServiceGrpcTransport, + "grpc", + "false", + ), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MetadataServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceClient), +) +@mock.patch.object( + MetadataServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MetadataServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_metadata_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_metadata_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -268,10 +345,18 @@ def test_metadata_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -292,9 +377,14 @@ def test_metadata_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -308,16 +398,23 @@ def test_metadata_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_metadata_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_metadata_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -330,16 +427,24 @@ def test_metadata_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), - (MetadataServiceAsyncClient, transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_metadata_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MetadataServiceClient, transports.MetadataServiceGrpcTransport, "grpc"), + ( + MetadataServiceAsyncClient, + transports.MetadataServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_metadata_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -354,10 +459,12 @@ def test_metadata_service_client_client_options_credentials_file(client_class, t def test_metadata_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MetadataServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -370,10 +477,11 @@ def test_metadata_service_client_client_options_from_dict(): ) -def test_create_metadata_store(transport: str = 'grpc', request_type=metadata_service.CreateMetadataStoreRequest): +def test_create_metadata_store( + transport: str = "grpc", request_type=metadata_service.CreateMetadataStoreRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -382,10 +490,10 @@ def test_create_metadata_store(transport: str = 'grpc', request_type=metadata_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_metadata_store(request) @@ -407,25 +515,27 @@ def test_create_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: client.create_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateMetadataStoreRequest() + @pytest.mark.asyncio -async def test_create_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateMetadataStoreRequest): +async def test_create_metadata_store_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.CreateMetadataStoreRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -434,11 +544,11 @@ async def test_create_metadata_store_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_metadata_store(request) @@ -459,20 +569,18 @@ async def test_create_metadata_store_async_from_dict(): def test_create_metadata_store_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataStoreRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_metadata_store), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_metadata_store(request) @@ -483,28 +591,25 @@ def test_create_metadata_store_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_metadata_store_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataStoreRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_metadata_store), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_metadata_store(request) @@ -515,30 +620,25 @@ async def test_create_metadata_store_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_metadata_store_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_metadata_store( - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) # Establish that the underlying call was made with the expected @@ -546,51 +646,49 @@ def test_create_metadata_store_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_store == gca_metadata_store.MetadataStore(name='name_value') + assert args[0].metadata_store == gca_metadata_store.MetadataStore( + name="name_value" + ) - assert args[0].metadata_store_id == 'metadata_store_id_value' + assert args[0].metadata_store_id == "metadata_store_id_value" def test_create_metadata_store_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_metadata_store( metadata_service.CreateMetadataStoreRequest(), - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) @pytest.mark.asyncio async def test_create_metadata_store_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_store), - '__call__') as call: + type(client.transport.create_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_metadata_store( - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) # Establish that the underlying call was made with the expected @@ -598,34 +696,35 @@ async def test_create_metadata_store_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_store == gca_metadata_store.MetadataStore(name='name_value') + assert args[0].metadata_store == gca_metadata_store.MetadataStore( + name="name_value" + ) - assert args[0].metadata_store_id == 'metadata_store_id_value' + assert args[0].metadata_store_id == "metadata_store_id_value" @pytest.mark.asyncio async def test_create_metadata_store_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_metadata_store( metadata_service.CreateMetadataStoreRequest(), - parent='parent_value', - metadata_store=gca_metadata_store.MetadataStore(name='name_value'), - metadata_store_id='metadata_store_id_value', + parent="parent_value", + metadata_store=gca_metadata_store.MetadataStore(name="name_value"), + metadata_store_id="metadata_store_id_value", ) -def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_service.GetMetadataStoreRequest): +def test_get_metadata_store( + transport: str = "grpc", request_type=metadata_service.GetMetadataStoreRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -634,13 +733,10 @@ def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_store.MetadataStore( - name='name_value', - - ) + call.return_value = metadata_store.MetadataStore(name="name_value",) response = client.get_metadata_store(request) @@ -654,7 +750,7 @@ def test_get_metadata_store(transport: str = 'grpc', request_type=metadata_servi assert isinstance(response, metadata_store.MetadataStore) - assert response.name == 'name_value' + assert response.name == "name_value" def test_get_metadata_store_from_dict(): @@ -665,25 +761,27 @@ def test_get_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: client.get_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetMetadataStoreRequest() + @pytest.mark.asyncio -async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetMetadataStoreRequest): +async def test_get_metadata_store_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.GetMetadataStoreRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -692,12 +790,12 @@ async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore( - name='name_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_store.MetadataStore(name="name_value",) + ) response = await client.get_metadata_store(request) @@ -710,7 +808,7 @@ async def test_get_metadata_store_async(transport: str = 'grpc_asyncio', request # Establish that the response is the type that we expect. assert isinstance(response, metadata_store.MetadataStore) - assert response.name == 'name_value' + assert response.name == "name_value" @pytest.mark.asyncio @@ -719,19 +817,17 @@ async def test_get_metadata_store_async_from_dict(): def test_get_metadata_store_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: call.return_value = metadata_store.MetadataStore() client.get_metadata_store(request) @@ -743,28 +839,25 @@ def test_get_metadata_store_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_metadata_store_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore()) + type(client.transport.get_metadata_store), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_store.MetadataStore() + ) await client.get_metadata_store(request) @@ -775,99 +868,85 @@ async def test_get_metadata_store_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_metadata_store_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_store.MetadataStore() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_metadata_store( - name='name_value', - ) + client.get_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_metadata_store_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_metadata_store( - metadata_service.GetMetadataStoreRequest(), - name='name_value', + metadata_service.GetMetadataStoreRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_metadata_store_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_store), - '__call__') as call: + type(client.transport.get_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_store.MetadataStore() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_store.MetadataStore()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_store.MetadataStore() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_metadata_store( - name='name_value', - ) + response = await client.get_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_metadata_store_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_metadata_store( - metadata_service.GetMetadataStoreRequest(), - name='name_value', + metadata_service.GetMetadataStoreRequest(), name="name_value", ) -def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_service.ListMetadataStoresRequest): +def test_list_metadata_stores( + transport: str = "grpc", request_type=metadata_service.ListMetadataStoresRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -876,12 +955,11 @@ def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_ser # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataStoresResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_metadata_stores(request) @@ -896,7 +974,7 @@ def test_list_metadata_stores(transport: str = 'grpc', request_type=metadata_ser assert isinstance(response, pagers.ListMetadataStoresPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_metadata_stores_from_dict(): @@ -907,25 +985,27 @@ def test_list_metadata_stores_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: client.list_metadata_stores() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListMetadataStoresRequest() + @pytest.mark.asyncio -async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListMetadataStoresRequest): +async def test_list_metadata_stores_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.ListMetadataStoresRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -934,12 +1014,14 @@ async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataStoresResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_metadata_stores(request) @@ -952,7 +1034,7 @@ async def test_list_metadata_stores_async(transport: str = 'grpc_asyncio', reque # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListMetadataStoresAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -961,19 +1043,17 @@ async def test_list_metadata_stores_async_from_dict(): def test_list_metadata_stores_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataStoresRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: call.return_value = metadata_service.ListMetadataStoresResponse() client.list_metadata_stores(request) @@ -985,28 +1065,25 @@ def test_list_metadata_stores_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_metadata_stores_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataStoresRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse()) + type(client.transport.list_metadata_stores), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataStoresResponse() + ) await client.list_metadata_stores(request) @@ -1017,104 +1094,87 @@ async def test_list_metadata_stores_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_metadata_stores_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataStoresResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_metadata_stores( - parent='parent_value', - ) + client.list_metadata_stores(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_metadata_stores_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_metadata_stores( - metadata_service.ListMetadataStoresRequest(), - parent='parent_value', + metadata_service.ListMetadataStoresRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_metadata_stores_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataStoresResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataStoresResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataStoresResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_metadata_stores( - parent='parent_value', - ) + response = await client.list_metadata_stores(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_metadata_stores_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_metadata_stores( - metadata_service.ListMetadataStoresRequest(), - parent='parent_value', + metadata_service.ListMetadataStoresRequest(), parent="parent_value", ) def test_list_metadata_stores_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1123,17 +1183,14 @@ def test_list_metadata_stores_pager(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1146,9 +1203,7 @@ def test_list_metadata_stores_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_metadata_stores(request={}) @@ -1156,18 +1211,16 @@ def test_list_metadata_stores_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, metadata_store.MetadataStore) - for i in results) + assert all(isinstance(i, metadata_store.MetadataStore) for i in results) + def test_list_metadata_stores_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__') as call: + type(client.transport.list_metadata_stores), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1176,17 +1229,14 @@ def test_list_metadata_stores_pages(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1197,19 +1247,20 @@ def test_list_metadata_stores_pages(): RuntimeError, ) pages = list(client.list_metadata_stores(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_metadata_stores_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_stores), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1218,17 +1269,14 @@ async def test_list_metadata_stores_async_pager(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1239,25 +1287,25 @@ async def test_list_metadata_stores_async_pager(): RuntimeError, ) async_pager = await client.list_metadata_stores(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, metadata_store.MetadataStore) - for i in responses) + assert all(isinstance(i, metadata_store.MetadataStore) for i in responses) + @pytest.mark.asyncio async def test_list_metadata_stores_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_stores), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_stores), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataStoresResponse( @@ -1266,17 +1314,14 @@ async def test_list_metadata_stores_async_pages(): metadata_store.MetadataStore(), metadata_store.MetadataStore(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[], - next_page_token='def', + metadata_stores=[], next_page_token="def", ), metadata_service.ListMetadataStoresResponse( - metadata_stores=[ - metadata_store.MetadataStore(), - ], - next_page_token='ghi', + metadata_stores=[metadata_store.MetadataStore(),], + next_page_token="ghi", ), metadata_service.ListMetadataStoresResponse( metadata_stores=[ @@ -1289,14 +1334,15 @@ async def test_list_metadata_stores_async_pages(): pages = [] async for page_ in (await client.list_metadata_stores(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_metadata_store(transport: str = 'grpc', request_type=metadata_service.DeleteMetadataStoreRequest): +def test_delete_metadata_store( + transport: str = "grpc", request_type=metadata_service.DeleteMetadataStoreRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1305,10 +1351,10 @@ def test_delete_metadata_store(transport: str = 'grpc', request_type=metadata_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_metadata_store(request) @@ -1330,25 +1376,27 @@ def test_delete_metadata_store_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: client.delete_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.DeleteMetadataStoreRequest() + @pytest.mark.asyncio -async def test_delete_metadata_store_async(transport: str = 'grpc_asyncio', request_type=metadata_service.DeleteMetadataStoreRequest): +async def test_delete_metadata_store_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.DeleteMetadataStoreRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1357,11 +1405,11 @@ async def test_delete_metadata_store_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_metadata_store(request) @@ -1382,20 +1430,18 @@ async def test_delete_metadata_store_async_from_dict(): def test_delete_metadata_store_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_metadata_store), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_metadata_store(request) @@ -1406,28 +1452,25 @@ def test_delete_metadata_store_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_metadata_store_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteMetadataStoreRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_metadata_store), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_metadata_store(request) @@ -1438,101 +1481,85 @@ async def test_delete_metadata_store_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_metadata_store_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_metadata_store( - name='name_value', - ) + client.delete_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_metadata_store_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_metadata_store( - metadata_service.DeleteMetadataStoreRequest(), - name='name_value', + metadata_service.DeleteMetadataStoreRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_metadata_store_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_metadata_store), - '__call__') as call: + type(client.transport.delete_metadata_store), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_metadata_store( - name='name_value', - ) + response = await client.delete_metadata_store(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_metadata_store_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_metadata_store( - metadata_service.DeleteMetadataStoreRequest(), - name='name_value', + metadata_service.DeleteMetadataStoreRequest(), name="name_value", ) -def test_create_artifact(transport: str = 'grpc', request_type=metadata_service.CreateArtifactRequest): +def test_create_artifact( + transport: str = "grpc", request_type=metadata_service.CreateArtifactRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1540,27 +1567,17 @@ def test_create_artifact(transport: str = 'grpc', request_type=metadata_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact( - name='name_value', - - display_name='display_name_value', - - uri='uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", state=gca_artifact.Artifact.State.PENDING, - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.create_artifact(request) @@ -1575,21 +1592,21 @@ def test_create_artifact(transport: str = 'grpc', request_type=metadata_service. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_artifact_from_dict(): @@ -1600,25 +1617,24 @@ def test_create_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: client.create_artifact() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateArtifactRequest() + @pytest.mark.asyncio -async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateArtifactRequest): +async def test_create_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.CreateArtifactRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1626,20 +1642,20 @@ async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact( - name='name_value', - display_name='display_name_value', - uri='uri_value', - etag='etag_value', - state=gca_artifact.Artifact.State.PENDING, - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact( + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=gca_artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.create_artifact(request) @@ -1652,21 +1668,21 @@ async def test_create_artifact_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -1675,19 +1691,15 @@ async def test_create_artifact_async_from_dict(): def test_create_artifact_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateArtifactRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: call.return_value = gca_artifact.Artifact() client.create_artifact(request) @@ -1699,28 +1711,23 @@ def test_create_artifact_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_artifact_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateArtifactRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) await client.create_artifact(request) @@ -1731,30 +1738,23 @@ async def test_create_artifact_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_artifact_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_artifact( - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) # Establish that the underlying call was made with the expected @@ -1762,49 +1762,45 @@ def test_create_artifact_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].artifact_id == 'artifact_id_value' + assert args[0].artifact_id == "artifact_id_value" def test_create_artifact_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_artifact( metadata_service.CreateArtifactRequest(), - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) @pytest.mark.asyncio async def test_create_artifact_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_artifact( - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) # Establish that the underlying call was made with the expected @@ -1812,34 +1808,33 @@ async def test_create_artifact_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].artifact_id == 'artifact_id_value' + assert args[0].artifact_id == "artifact_id_value" @pytest.mark.asyncio async def test_create_artifact_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_artifact( metadata_service.CreateArtifactRequest(), - parent='parent_value', - artifact=gca_artifact.Artifact(name='name_value'), - artifact_id='artifact_id_value', + parent="parent_value", + artifact=gca_artifact.Artifact(name="name_value"), + artifact_id="artifact_id_value", ) -def test_get_artifact(transport: str = 'grpc', request_type=metadata_service.GetArtifactRequest): +def test_get_artifact( + transport: str = "grpc", request_type=metadata_service.GetArtifactRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1847,56 +1842,46 @@ def test_get_artifact(transport: str = 'grpc', request_type=metadata_service.Get request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = artifact.Artifact( - name='name_value', + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) - display_name='display_name_value', + response = client.get_artifact(request) - uri='uri_value', + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] - etag='etag_value', - - state=artifact.Artifact.State.PENDING, - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - - ) - - response = client.get_artifact(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == metadata_service.GetArtifactRequest() + assert args[0] == metadata_service.GetArtifactRequest() # Establish that the response is the type that we expect. assert isinstance(response, artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_artifact_from_dict(): @@ -1907,25 +1892,24 @@ def test_get_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: client.get_artifact() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetArtifactRequest() + @pytest.mark.asyncio -async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetArtifactRequest): +async def test_get_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetArtifactRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1933,20 +1917,20 @@ async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact( - name='name_value', - display_name='display_name_value', - uri='uri_value', - etag='etag_value', - state=artifact.Artifact.State.PENDING, - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + artifact.Artifact( + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.get_artifact(request) @@ -1959,21 +1943,21 @@ async def test_get_artifact_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -1982,19 +1966,15 @@ async def test_get_artifact_async_from_dict(): def test_get_artifact_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetArtifactRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: call.return_value = artifact.Artifact() client.get_artifact(request) @@ -2006,27 +1986,20 @@ def test_get_artifact_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_artifact_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetArtifactRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact()) await client.get_artifact(request) @@ -2038,99 +2011,79 @@ async def test_get_artifact_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_artifact_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = artifact.Artifact() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_artifact( - name='name_value', - ) + client.get_artifact(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_artifact_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_artifact( - metadata_service.GetArtifactRequest(), - name='name_value', + metadata_service.GetArtifactRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_artifact_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = artifact.Artifact() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(artifact.Artifact()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_artifact( - name='name_value', - ) + response = await client.get_artifact(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_artifact_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_artifact( - metadata_service.GetArtifactRequest(), - name='name_value', + metadata_service.GetArtifactRequest(), name="name_value", ) -def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.ListArtifactsRequest): +def test_list_artifacts( + transport: str = "grpc", request_type=metadata_service.ListArtifactsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2138,13 +2091,10 @@ def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListArtifactsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_artifacts(request) @@ -2159,7 +2109,7 @@ def test_list_artifacts(transport: str = 'grpc', request_type=metadata_service.L assert isinstance(response, pagers.ListArtifactsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_artifacts_from_dict(): @@ -2170,25 +2120,24 @@ def test_list_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: client.list_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListArtifactsRequest() + @pytest.mark.asyncio -async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListArtifactsRequest): +async def test_list_artifacts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListArtifactsRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2196,13 +2145,13 @@ async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListArtifactsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_artifacts(request) @@ -2215,7 +2164,7 @@ async def test_list_artifacts_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListArtifactsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2224,19 +2173,15 @@ async def test_list_artifacts_async_from_dict(): def test_list_artifacts_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListArtifactsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: call.return_value = metadata_service.ListArtifactsResponse() client.list_artifacts(request) @@ -2248,28 +2193,23 @@ def test_list_artifacts_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_artifacts_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListArtifactsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse()) + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListArtifactsResponse() + ) await client.list_artifacts(request) @@ -2280,104 +2220,81 @@ async def test_list_artifacts_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_artifacts_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListArtifactsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_artifacts( - parent='parent_value', - ) + client.list_artifacts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_artifacts_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_artifacts( - metadata_service.ListArtifactsRequest(), - parent='parent_value', + metadata_service.ListArtifactsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_artifacts_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListArtifactsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListArtifactsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListArtifactsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_artifacts( - parent='parent_value', - ) + response = await client.list_artifacts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_artifacts_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_artifacts( - metadata_service.ListArtifactsRequest(), - parent='parent_value', + metadata_service.ListArtifactsRequest(), parent="parent_value", ) def test_list_artifacts_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2386,32 +2303,23 @@ def test_list_artifacts_pager(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_artifacts(request={}) @@ -2419,18 +2327,14 @@ def test_list_artifacts_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, artifact.Artifact) - for i in results) + assert all(isinstance(i, artifact.Artifact) for i in results) + def test_list_artifacts_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_artifacts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2439,40 +2343,32 @@ def test_list_artifacts_pages(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) pages = list(client.list_artifacts(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_artifacts_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_artifacts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_artifacts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2481,46 +2377,37 @@ async def test_list_artifacts_async_pager(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) async_pager = await client.list_artifacts(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, artifact.Artifact) - for i in responses) + assert all(isinstance(i, artifact.Artifact) for i in responses) + @pytest.mark.asyncio async def test_list_artifacts_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_artifacts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_artifacts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListArtifactsResponse( @@ -2529,37 +2416,31 @@ async def test_list_artifacts_async_pages(): artifact.Artifact(), artifact.Artifact(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListArtifactsResponse( - artifacts=[], - next_page_token='def', + artifacts=[], next_page_token="def", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - ], - next_page_token='ghi', + artifacts=[artifact.Artifact(),], next_page_token="ghi", ), metadata_service.ListArtifactsResponse( - artifacts=[ - artifact.Artifact(), - artifact.Artifact(), - ], + artifacts=[artifact.Artifact(), artifact.Artifact(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_artifacts(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_artifact(transport: str = 'grpc', request_type=metadata_service.UpdateArtifactRequest): +def test_update_artifact( + transport: str = "grpc", request_type=metadata_service.UpdateArtifactRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2567,27 +2448,17 @@ def test_update_artifact(transport: str = 'grpc', request_type=metadata_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact( - name='name_value', - - display_name='display_name_value', - - uri='uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", state=gca_artifact.Artifact.State.PENDING, - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.update_artifact(request) @@ -2602,21 +2473,21 @@ def test_update_artifact(transport: str = 'grpc', request_type=metadata_service. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_update_artifact_from_dict(): @@ -2627,25 +2498,24 @@ def test_update_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: client.update_artifact() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateArtifactRequest() + @pytest.mark.asyncio -async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateArtifactRequest): +async def test_update_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.UpdateArtifactRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2653,20 +2523,20 @@ async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact( - name='name_value', - display_name='display_name_value', - uri='uri_value', - etag='etag_value', - state=gca_artifact.Artifact.State.PENDING, - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact( + name="name_value", + display_name="display_name_value", + uri="uri_value", + etag="etag_value", + state=gca_artifact.Artifact.State.PENDING, + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.update_artifact(request) @@ -2679,21 +2549,21 @@ async def test_update_artifact_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_artifact.Artifact) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.uri == 'uri_value' + assert response.uri == "uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" assert response.state == gca_artifact.Artifact.State.PENDING - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -2702,19 +2572,15 @@ async def test_update_artifact_async_from_dict(): def test_update_artifact_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateArtifactRequest() - request.artifact.name = 'artifact.name/value' + request.artifact.name = "artifact.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: call.return_value = gca_artifact.Artifact() client.update_artifact(request) @@ -2726,28 +2592,25 @@ def test_update_artifact_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'artifact.name=artifact.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "artifact.name=artifact.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_artifact_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateArtifactRequest() - request.artifact.name = 'artifact.name/value' + request.artifact.name = "artifact.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) await client.update_artifact(request) @@ -2758,29 +2621,24 @@ async def test_update_artifact_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'artifact.name=artifact.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "artifact.name=artifact.name/value",) in kw[ + "metadata" + ] def test_update_artifact_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_artifact( - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2788,45 +2646,41 @@ def test_update_artifact_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_artifact_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_artifact( metadata_service.UpdateArtifactRequest(), - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_artifact_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_artifact), - '__call__') as call: + with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_artifact.Artifact() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_artifact.Artifact()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_artifact.Artifact() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_artifact( - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2834,31 +2688,30 @@ async def test_update_artifact_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].artifact == gca_artifact.Artifact(name='name_value') + assert args[0].artifact == gca_artifact.Artifact(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_artifact_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_artifact( metadata_service.UpdateArtifactRequest(), - artifact=gca_artifact.Artifact(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + artifact=gca_artifact.Artifact(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_create_context(transport: str = 'grpc', request_type=metadata_service.CreateContextRequest): +def test_create_context( + transport: str = "grpc", request_type=metadata_service.CreateContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2866,25 +2719,16 @@ def test_create_context(transport: str = 'grpc', request_type=metadata_service.C request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - - parent_contexts=['parent_contexts_value'], - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.create_context(request) @@ -2899,19 +2743,19 @@ def test_create_context(transport: str = 'grpc', request_type=metadata_service.C assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_context_from_dict(): @@ -2922,25 +2766,24 @@ def test_create_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: client.create_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateContextRequest() + @pytest.mark.asyncio -async def test_create_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateContextRequest): +async def test_create_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.CreateContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2948,19 +2791,19 @@ async def test_create_context_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context( - name='name_value', - display_name='display_name_value', - etag='etag_value', - parent_contexts=['parent_contexts_value'], - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.create_context(request) @@ -2973,19 +2816,19 @@ async def test_create_context_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -2994,19 +2837,15 @@ async def test_create_context_async_from_dict(): def test_create_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateContextRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: call.return_value = gca_context.Context() client.create_context(request) @@ -3018,27 +2857,20 @@ def test_create_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateContextRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) await client.create_context(request) @@ -3050,30 +2882,23 @@ async def test_create_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_context( - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) # Establish that the underlying call was made with the expected @@ -3081,39 +2906,33 @@ def test_create_context_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].context_id == 'context_id_value' + assert args[0].context_id == "context_id_value" def test_create_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_context( metadata_service.CreateContextRequest(), - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) @pytest.mark.asyncio async def test_create_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_context), - '__call__') as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() @@ -3121,9 +2940,9 @@ async def test_create_context_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_context( - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) # Establish that the underlying call was made with the expected @@ -3131,34 +2950,33 @@ async def test_create_context_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].context_id == 'context_id_value' + assert args[0].context_id == "context_id_value" @pytest.mark.asyncio async def test_create_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_context( metadata_service.CreateContextRequest(), - parent='parent_value', - context=gca_context.Context(name='name_value'), - context_id='context_id_value', + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) -def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetContextRequest): +def test_get_context( + transport: str = "grpc", request_type=metadata_service.GetContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3166,25 +2984,16 @@ def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetC request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = context.Context( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - - parent_contexts=['parent_contexts_value'], - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.get_context(request) @@ -3199,19 +3008,19 @@ def test_get_context(transport: str = 'grpc', request_type=metadata_service.GetC assert isinstance(response, context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_context_from_dict(): @@ -3222,25 +3031,24 @@ def test_get_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: client.get_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetContextRequest() + @pytest.mark.asyncio -async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetContextRequest): +async def test_get_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3248,19 +3056,19 @@ async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=m request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context( - name='name_value', - display_name='display_name_value', - etag='etag_value', - parent_contexts=['parent_contexts_value'], - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.get_context(request) @@ -3273,19 +3081,19 @@ async def test_get_context_async(transport: str = 'grpc_asyncio', request_type=m # Establish that the response is the type that we expect. assert isinstance(response, context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -3294,19 +3102,15 @@ async def test_get_context_async_from_dict(): def test_get_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: call.return_value = context.Context() client.get_context(request) @@ -3318,27 +3122,20 @@ def test_get_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) await client.get_context(request) @@ -3350,99 +3147,79 @@ async def test_get_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_context( - name='name_value', - ) + client.get_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_context( - metadata_service.GetContextRequest(), - name='name_value', + metadata_service.GetContextRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_context), - '__call__') as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = context.Context() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_context( - name='name_value', - ) + response = await client.get_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_context( - metadata_service.GetContextRequest(), - name='name_value', + metadata_service.GetContextRequest(), name="name_value", ) -def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.ListContextsRequest): +def test_list_contexts( + transport: str = "grpc", request_type=metadata_service.ListContextsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3450,13 +3227,10 @@ def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.Li request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListContextsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_contexts(request) @@ -3471,7 +3245,7 @@ def test_list_contexts(transport: str = 'grpc', request_type=metadata_service.Li assert isinstance(response, pagers.ListContextsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_contexts_from_dict(): @@ -3482,25 +3256,24 @@ def test_list_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: client.list_contexts() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListContextsRequest() + @pytest.mark.asyncio -async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListContextsRequest): +async def test_list_contexts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListContextsRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3508,13 +3281,13 @@ async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_contexts(request) @@ -3527,7 +3300,7 @@ async def test_list_contexts_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListContextsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3536,19 +3309,15 @@ async def test_list_contexts_async_from_dict(): def test_list_contexts_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListContextsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: call.return_value = metadata_service.ListContextsResponse() client.list_contexts(request) @@ -3560,28 +3329,23 @@ def test_list_contexts_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_contexts_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListContextsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse()) + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse() + ) await client.list_contexts(request) @@ -3592,138 +3356,100 @@ async def test_list_contexts_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_contexts_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListContextsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_contexts( - parent='parent_value', - ) + client.list_contexts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_contexts_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_contexts( - metadata_service.ListContextsRequest(), - parent='parent_value', + metadata_service.ListContextsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_contexts_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListContextsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListContextsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_contexts( - parent='parent_value', - ) + response = await client.list_contexts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_contexts_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_contexts( - metadata_service.ListContextsRequest(), - parent='parent_value', + metadata_service.ListContextsRequest(), parent="parent_value", ) def test_list_contexts_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_contexts(request={}) @@ -3731,147 +3457,102 @@ def test_list_contexts_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, context.Context) - for i in results) + assert all(isinstance(i, context.Context) for i in results) + def test_list_contexts_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), - '__call__') as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) pages = list(client.list_contexts(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_contexts_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_contexts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) async_pager = await client.list_contexts(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, context.Context) - for i in responses) + assert all(isinstance(i, context.Context) for i in responses) + @pytest.mark.asyncio async def test_list_contexts_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_contexts), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - context.Context(), - ], - next_page_token='abc', - ), - metadata_service.ListContextsResponse( - contexts=[], - next_page_token='def', + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - ], - next_page_token='ghi', + contexts=[context.Context(),], next_page_token="ghi", ), metadata_service.ListContextsResponse( - contexts=[ - context.Context(), - context.Context(), - ], + contexts=[context.Context(), context.Context(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_contexts(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_context(transport: str = 'grpc', request_type=metadata_service.UpdateContextRequest): +def test_update_context( + transport: str = "grpc", request_type=metadata_service.UpdateContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3879,25 +3560,16 @@ def test_update_context(transport: str = 'grpc', request_type=metadata_service.U request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - - parent_contexts=['parent_contexts_value'], - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.update_context(request) @@ -3912,19 +3584,19 @@ def test_update_context(transport: str = 'grpc', request_type=metadata_service.U assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_update_context_from_dict(): @@ -3935,25 +3607,24 @@ def test_update_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: client.update_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateContextRequest() + @pytest.mark.asyncio -async def test_update_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateContextRequest): +async def test_update_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.UpdateContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3961,19 +3632,19 @@ async def test_update_context_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context( - name='name_value', - display_name='display_name_value', - etag='etag_value', - parent_contexts=['parent_contexts_value'], - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.update_context(request) @@ -3986,19 +3657,19 @@ async def test_update_context_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_context.Context) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.parent_contexts == ['parent_contexts_value'] + assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -4007,19 +3678,15 @@ async def test_update_context_async_from_dict(): def test_update_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateContextRequest() - request.context.name = 'context.name/value' + request.context.name = "context.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: call.return_value = gca_context.Context() client.update_context(request) @@ -4031,27 +3698,22 @@ def test_update_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context.name=context.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateContextRequest() - request.context.name = 'context.name/value' + request.context.name = "context.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) await client.update_context(request) @@ -4063,29 +3725,24 @@ async def test_update_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context.name=context.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ + "metadata" + ] def test_update_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_context( - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -4093,36 +3750,30 @@ def test_update_context_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_context( metadata_service.UpdateContextRequest(), - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_context), - '__call__') as call: + with mock.patch.object(type(client.transport.update_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_context.Context() @@ -4130,8 +3781,8 @@ async def test_update_context_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_context( - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -4139,31 +3790,30 @@ async def test_update_context_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == gca_context.Context(name='name_value') + assert args[0].context == gca_context.Context(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_context( metadata_service.UpdateContextRequest(), - context=gca_context.Context(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + context=gca_context.Context(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_context(transport: str = 'grpc', request_type=metadata_service.DeleteContextRequest): +def test_delete_context( + transport: str = "grpc", request_type=metadata_service.DeleteContextRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4171,11 +3821,9 @@ def test_delete_context(transport: str = 'grpc', request_type=metadata_service.D request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_context(request) @@ -4197,25 +3845,24 @@ def test_delete_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: client.delete_context() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.DeleteContextRequest() + @pytest.mark.asyncio -async def test_delete_context_async(transport: str = 'grpc_asyncio', request_type=metadata_service.DeleteContextRequest): +async def test_delete_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.DeleteContextRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4223,12 +3870,10 @@ async def test_delete_context_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_context(request) @@ -4249,20 +3894,16 @@ async def test_delete_context_async_from_dict(): def test_delete_context_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_context(request) @@ -4273,28 +3914,23 @@ def test_delete_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_context_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteContextRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_context(request) @@ -4305,101 +3941,82 @@ async def test_delete_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_context_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_context( - name='name_value', - ) + client.delete_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_context_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_context( - metadata_service.DeleteContextRequest(), - name='name_value', + metadata_service.DeleteContextRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_context_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_context), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_context( - name='name_value', - ) + response = await client.delete_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_context_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_context( - metadata_service.DeleteContextRequest(), - name='name_value', + metadata_service.DeleteContextRequest(), name="name_value", ) -def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_type=metadata_service.AddContextArtifactsAndExecutionsRequest): +def test_add_context_artifacts_and_executions( + transport: str = "grpc", + request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4408,11 +4025,10 @@ def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_t # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse( - ) + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() response = client.add_context_artifacts_and_executions(request) @@ -4424,7 +4040,9 @@ def test_add_context_artifacts_and_executions(transport: str = 'grpc', request_t # Establish that the response is the type that we expect. - assert isinstance(response, metadata_service.AddContextArtifactsAndExecutionsResponse) + assert isinstance( + response, metadata_service.AddContextArtifactsAndExecutionsResponse + ) def test_add_context_artifacts_and_executions_from_dict(): @@ -4435,25 +4053,27 @@ def test_add_context_artifacts_and_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: client.add_context_artifacts_and_executions() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddContextArtifactsAndExecutionsRequest): +async def test_add_context_artifacts_and_executions_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4462,11 +4082,12 @@ async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) response = await client.add_context_artifacts_and_executions(request) @@ -4477,7 +4098,9 @@ async def test_add_context_artifacts_and_executions_async(transport: str = 'grpc assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, metadata_service.AddContextArtifactsAndExecutionsResponse) + assert isinstance( + response, metadata_service.AddContextArtifactsAndExecutionsResponse + ) @pytest.mark.asyncio @@ -4486,19 +4109,17 @@ async def test_add_context_artifacts_and_executions_async_from_dict(): def test_add_context_artifacts_and_executions_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextArtifactsAndExecutionsRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() client.add_context_artifacts_and_executions(request) @@ -4510,28 +4131,25 @@ def test_add_context_artifacts_and_executions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextArtifactsAndExecutionsRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse()) + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) await client.add_context_artifacts_and_executions(request) @@ -4542,30 +4160,25 @@ async def test_add_context_artifacts_and_executions_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] def test_add_context_artifacts_and_executions_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.add_context_artifacts_and_executions( - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) # Establish that the underlying call was made with the expected @@ -4573,49 +4186,47 @@ def test_add_context_artifacts_and_executions_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].artifacts == ['artifacts_value'] + assert args[0].artifacts == ["artifacts_value"] - assert args[0].executions == ['executions_value'] + assert args[0].executions == ["executions_value"] def test_add_context_artifacts_and_executions_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.add_context_artifacts_and_executions( metadata_service.AddContextArtifactsAndExecutionsRequest(), - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), - '__call__') as call: + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextArtifactsAndExecutionsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.add_context_artifacts_and_executions( - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) # Establish that the underlying call was made with the expected @@ -4623,34 +4234,33 @@ async def test_add_context_artifacts_and_executions_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].artifacts == ['artifacts_value'] + assert args[0].artifacts == ["artifacts_value"] - assert args[0].executions == ['executions_value'] + assert args[0].executions == ["executions_value"] @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.add_context_artifacts_and_executions( metadata_service.AddContextArtifactsAndExecutionsRequest(), - context='context_value', - artifacts=['artifacts_value'], - executions=['executions_value'], + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], ) -def test_add_context_children(transport: str = 'grpc', request_type=metadata_service.AddContextChildrenRequest): +def test_add_context_children( + transport: str = "grpc", request_type=metadata_service.AddContextChildrenRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4659,11 +4269,10 @@ def test_add_context_children(transport: str = 'grpc', request_type=metadata_ser # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextChildrenResponse( - ) + call.return_value = metadata_service.AddContextChildrenResponse() response = client.add_context_children(request) @@ -4686,25 +4295,27 @@ def test_add_context_children_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: client.add_context_children() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.AddContextChildrenRequest() + @pytest.mark.asyncio -async def test_add_context_children_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddContextChildrenRequest): +async def test_add_context_children_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddContextChildrenRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4713,11 +4324,12 @@ async def test_add_context_children_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) response = await client.add_context_children(request) @@ -4737,19 +4349,17 @@ async def test_add_context_children_async_from_dict(): def test_add_context_children_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextChildrenRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: call.return_value = metadata_service.AddContextChildrenResponse() client.add_context_children(request) @@ -4761,28 +4371,25 @@ def test_add_context_children_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_context_children_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddContextChildrenRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse()) + type(client.transport.add_context_children), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) await client.add_context_children(request) @@ -4793,29 +4400,23 @@ async def test_add_context_children_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] def test_add_context_children_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextChildrenResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.add_context_children( - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", child_contexts=["child_contexts_value"], ) # Establish that the underlying call was made with the expected @@ -4823,45 +4424,42 @@ def test_add_context_children_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].child_contexts == ['child_contexts_value'] + assert args[0].child_contexts == ["child_contexts_value"] def test_add_context_children_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.add_context_children( metadata_service.AddContextChildrenRequest(), - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", + child_contexts=["child_contexts_value"], ) @pytest.mark.asyncio async def test_add_context_children_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_children), - '__call__') as call: + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddContextChildrenResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddContextChildrenResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.add_context_children( - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", child_contexts=["child_contexts_value"], ) # Establish that the underlying call was made with the expected @@ -4869,31 +4467,31 @@ async def test_add_context_children_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" - assert args[0].child_contexts == ['child_contexts_value'] + assert args[0].child_contexts == ["child_contexts_value"] @pytest.mark.asyncio async def test_add_context_children_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.add_context_children( metadata_service.AddContextChildrenRequest(), - context='context_value', - child_contexts=['child_contexts_value'], + context="context_value", + child_contexts=["child_contexts_value"], ) -def test_query_context_lineage_subgraph(transport: str = 'grpc', request_type=metadata_service.QueryContextLineageSubgraphRequest): +def test_query_context_lineage_subgraph( + transport: str = "grpc", + request_type=metadata_service.QueryContextLineageSubgraphRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4902,11 +4500,10 @@ def test_query_context_lineage_subgraph(transport: str = 'grpc', request_type=me # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph( - ) + call.return_value = lineage_subgraph.LineageSubgraph() response = client.query_context_lineage_subgraph(request) @@ -4929,25 +4526,27 @@ def test_query_context_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: client.query_context_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryContextLineageSubgraphRequest): +async def test_query_context_lineage_subgraph_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.QueryContextLineageSubgraphRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4956,11 +4555,12 @@ async def test_query_context_lineage_subgraph_async(transport: str = 'grpc_async # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) response = await client.query_context_lineage_subgraph(request) @@ -4980,19 +4580,17 @@ async def test_query_context_lineage_subgraph_async_from_dict(): def test_query_context_lineage_subgraph_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryContextLineageSubgraphRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: call.return_value = lineage_subgraph.LineageSubgraph() client.query_context_lineage_subgraph(request) @@ -5004,28 +4602,25 @@ def test_query_context_lineage_subgraph_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio async def test_query_context_lineage_subgraph_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryContextLineageSubgraphRequest() - request.context = 'context/value' + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) await client.query_context_lineage_subgraph(request) @@ -5036,99 +4631,87 @@ async def test_query_context_lineage_subgraph_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'context=context/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] def test_query_context_lineage_subgraph_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.query_context_lineage_subgraph( - context='context_value', - ) + client.query_context_lineage_subgraph(context="context_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" def test_query_context_lineage_subgraph_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.query_context_lineage_subgraph( metadata_service.QueryContextLineageSubgraphRequest(), - context='context_value', + context="context_value", ) @pytest.mark.asyncio async def test_query_context_lineage_subgraph_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), - '__call__') as call: + type(client.transport.query_context_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.query_context_lineage_subgraph( - context='context_value', - ) + response = await client.query_context_lineage_subgraph(context="context_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == 'context_value' + assert args[0].context == "context_value" @pytest.mark.asyncio async def test_query_context_lineage_subgraph_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.query_context_lineage_subgraph( metadata_service.QueryContextLineageSubgraphRequest(), - context='context_value', + context="context_value", ) -def test_create_execution(transport: str = 'grpc', request_type=metadata_service.CreateExecutionRequest): +def test_create_execution( + transport: str = "grpc", request_type=metadata_service.CreateExecutionRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5136,25 +4719,16 @@ def test_create_execution(transport: str = 'grpc', request_type=metadata_service request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=gca_execution.Execution.State.NEW, - - etag='etag_value', - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.create_execution(request) @@ -5169,19 +4743,19 @@ def test_create_execution(transport: str = 'grpc', request_type=metadata_service assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_execution_from_dict(): @@ -5192,25 +4766,25 @@ def test_create_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: client.create_execution() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateExecutionRequest() + @pytest.mark.asyncio -async def test_create_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateExecutionRequest): +async def test_create_execution_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.CreateExecutionRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5218,19 +4792,19 @@ async def test_create_execution_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution( - name='name_value', - display_name='display_name_value', - state=gca_execution.Execution.State.NEW, - etag='etag_value', - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.create_execution(request) @@ -5243,19 +4817,19 @@ async def test_create_execution_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -5264,19 +4838,15 @@ async def test_create_execution_async_from_dict(): def test_create_execution_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateExecutionRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: call.return_value = gca_execution.Execution() client.create_execution(request) @@ -5288,28 +4858,23 @@ def test_create_execution_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_execution_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateExecutionRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) await client.create_execution(request) @@ -5320,30 +4885,23 @@ async def test_create_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_execution_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_execution( - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) # Establish that the underlying call was made with the expected @@ -5351,49 +4909,45 @@ def test_create_execution_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].execution_id == 'execution_id_value' + assert args[0].execution_id == "execution_id_value" def test_create_execution_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_execution( metadata_service.CreateExecutionRequest(), - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) @pytest.mark.asyncio async def test_create_execution_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_execution( - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) # Establish that the underlying call was made with the expected @@ -5401,34 +4955,33 @@ async def test_create_execution_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].execution_id == 'execution_id_value' + assert args[0].execution_id == "execution_id_value" @pytest.mark.asyncio async def test_create_execution_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_execution( metadata_service.CreateExecutionRequest(), - parent='parent_value', - execution=gca_execution.Execution(name='name_value'), - execution_id='execution_id_value', + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) -def test_get_execution(transport: str = 'grpc', request_type=metadata_service.GetExecutionRequest): +def test_get_execution( + transport: str = "grpc", request_type=metadata_service.GetExecutionRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5436,25 +4989,16 @@ def test_get_execution(transport: str = 'grpc', request_type=metadata_service.Ge request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = execution.Execution( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=execution.Execution.State.NEW, - - etag='etag_value', - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.get_execution(request) @@ -5469,19 +5013,19 @@ def test_get_execution(transport: str = 'grpc', request_type=metadata_service.Ge assert isinstance(response, execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_execution_from_dict(): @@ -5492,25 +5036,24 @@ def test_get_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: client.get_execution() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetExecutionRequest() + @pytest.mark.asyncio -async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetExecutionRequest): +async def test_get_execution_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetExecutionRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5518,19 +5061,19 @@ async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution( - name='name_value', - display_name='display_name_value', - state=execution.Execution.State.NEW, - etag='etag_value', - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + execution.Execution( + name="name_value", + display_name="display_name_value", + state=execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.get_execution(request) @@ -5543,19 +5086,19 @@ async def test_get_execution_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -5564,19 +5107,15 @@ async def test_get_execution_async_from_dict(): def test_get_execution_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetExecutionRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: call.return_value = execution.Execution() client.get_execution(request) @@ -5588,27 +5127,20 @@ def test_get_execution_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_execution_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetExecutionRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) await client.get_execution(request) @@ -5620,99 +5152,79 @@ async def test_get_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_execution_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_execution( - name='name_value', - ) + client.get_execution(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_execution_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_execution( - metadata_service.GetExecutionRequest(), - name='name_value', + metadata_service.GetExecutionRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_execution_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = execution.Execution() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_execution( - name='name_value', - ) + response = await client.get_execution(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_execution_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_execution( - metadata_service.GetExecutionRequest(), - name='name_value', + metadata_service.GetExecutionRequest(), name="name_value", ) -def test_list_executions(transport: str = 'grpc', request_type=metadata_service.ListExecutionsRequest): +def test_list_executions( + transport: str = "grpc", request_type=metadata_service.ListExecutionsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5720,13 +5232,10 @@ def test_list_executions(transport: str = 'grpc', request_type=metadata_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListExecutionsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_executions(request) @@ -5741,7 +5250,7 @@ def test_list_executions(transport: str = 'grpc', request_type=metadata_service. assert isinstance(response, pagers.ListExecutionsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_executions_from_dict(): @@ -5752,25 +5261,24 @@ def test_list_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: client.list_executions() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListExecutionsRequest() + @pytest.mark.asyncio -async def test_list_executions_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListExecutionsRequest): +async def test_list_executions_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListExecutionsRequest +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5778,13 +5286,13 @@ async def test_list_executions_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_executions(request) @@ -5797,7 +5305,7 @@ async def test_list_executions_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListExecutionsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -5806,19 +5314,15 @@ async def test_list_executions_async_from_dict(): def test_list_executions_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListExecutionsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: call.return_value = metadata_service.ListExecutionsResponse() client.list_executions(request) @@ -5830,28 +5334,23 @@ def test_list_executions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_executions_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListExecutionsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse()) + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse() + ) await client.list_executions(request) @@ -5862,104 +5361,81 @@ async def test_list_executions_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_executions_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListExecutionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_executions( - parent='parent_value', - ) + client.list_executions(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_executions_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_executions( - metadata_service.ListExecutionsRequest(), - parent='parent_value', + metadata_service.ListExecutionsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_executions_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListExecutionsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListExecutionsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_executions( - parent='parent_value', - ) + response = await client.list_executions(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_executions_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_executions( - metadata_service.ListExecutionsRequest(), - parent='parent_value', + metadata_service.ListExecutionsRequest(), parent="parent_value", ) def test_list_executions_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -5968,32 +5444,23 @@ def test_list_executions_pager(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_executions(request={}) @@ -6001,18 +5468,14 @@ def test_list_executions_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, execution.Execution) - for i in results) + assert all(isinstance(i, execution.Execution) for i in results) + def test_list_executions_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -6021,40 +5484,32 @@ def test_list_executions_pages(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) pages = list(client.list_executions(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_executions_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_executions), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -6063,46 +5518,37 @@ async def test_list_executions_async_pager(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) async_pager = await client.list_executions(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, execution.Execution) - for i in responses) + assert all(isinstance(i, execution.Execution) for i in responses) + @pytest.mark.asyncio async def test_list_executions_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_executions), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListExecutionsResponse( @@ -6111,37 +5557,31 @@ async def test_list_executions_async_pages(): execution.Execution(), execution.Execution(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListExecutionsResponse( - executions=[], - next_page_token='def', + executions=[], next_page_token="def", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - ], - next_page_token='ghi', + executions=[execution.Execution(),], next_page_token="ghi", ), metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - ], + executions=[execution.Execution(), execution.Execution(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_executions(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_execution(transport: str = 'grpc', request_type=metadata_service.UpdateExecutionRequest): +def test_update_execution( + transport: str = "grpc", request_type=metadata_service.UpdateExecutionRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6149,25 +5589,16 @@ def test_update_execution(transport: str = 'grpc', request_type=metadata_service request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=gca_execution.Execution.State.NEW, - - etag='etag_value', - - schema_title='schema_title_value', - - schema_version='schema_version_value', - - description='description_value', - + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) response = client.update_execution(request) @@ -6182,19 +5613,19 @@ def test_update_execution(transport: str = 'grpc', request_type=metadata_service assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" def test_update_execution_from_dict(): @@ -6205,25 +5636,25 @@ def test_update_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: client.update_execution() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateExecutionRequest() + @pytest.mark.asyncio -async def test_update_execution_async(transport: str = 'grpc_asyncio', request_type=metadata_service.UpdateExecutionRequest): +async def test_update_execution_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.UpdateExecutionRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6231,19 +5662,19 @@ async def test_update_execution_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution( - name='name_value', - display_name='display_name_value', - state=gca_execution.Execution.State.NEW, - etag='etag_value', - schema_title='schema_title_value', - schema_version='schema_version_value', - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) response = await client.update_execution(request) @@ -6256,19 +5687,19 @@ async def test_update_execution_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, gca_execution.Execution) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == gca_execution.Execution.State.NEW - assert response.etag == 'etag_value' + assert response.etag == "etag_value" - assert response.schema_title == 'schema_title_value' + assert response.schema_title == "schema_title_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -6277,19 +5708,15 @@ async def test_update_execution_async_from_dict(): def test_update_execution_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateExecutionRequest() - request.execution.name = 'execution.name/value' + request.execution.name = "execution.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: call.return_value = gca_execution.Execution() client.update_execution(request) @@ -6301,28 +5728,25 @@ def test_update_execution_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution.name=execution.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_execution_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.UpdateExecutionRequest() - request.execution.name = 'execution.name/value' + request.execution.name = "execution.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) await client.update_execution(request) @@ -6333,29 +5757,24 @@ async def test_update_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution.name=execution.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ + "metadata" + ] def test_update_execution_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_execution( - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -6363,45 +5782,41 @@ def test_update_execution_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_execution_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_execution( metadata_service.UpdateExecutionRequest(), - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_execution_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_execution), - '__call__') as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_execution.Execution() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_execution.Execution()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_execution( - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -6409,31 +5824,30 @@ async def test_update_execution_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == gca_execution.Execution(name='name_value') + assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_execution_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_execution( metadata_service.UpdateExecutionRequest(), - execution=gca_execution.Execution(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_add_execution_events(transport: str = 'grpc', request_type=metadata_service.AddExecutionEventsRequest): +def test_add_execution_events( + transport: str = "grpc", request_type=metadata_service.AddExecutionEventsRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6442,11 +5856,10 @@ def test_add_execution_events(transport: str = 'grpc', request_type=metadata_ser # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddExecutionEventsResponse( - ) + call.return_value = metadata_service.AddExecutionEventsResponse() response = client.add_execution_events(request) @@ -6469,25 +5882,27 @@ def test_add_execution_events_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: client.add_execution_events() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.AddExecutionEventsRequest() + @pytest.mark.asyncio -async def test_add_execution_events_async(transport: str = 'grpc_asyncio', request_type=metadata_service.AddExecutionEventsRequest): +async def test_add_execution_events_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddExecutionEventsRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6496,11 +5911,12 @@ async def test_add_execution_events_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddExecutionEventsResponse() + ) response = await client.add_execution_events(request) @@ -6520,19 +5936,17 @@ async def test_add_execution_events_async_from_dict(): def test_add_execution_events_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddExecutionEventsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: call.return_value = metadata_service.AddExecutionEventsResponse() client.add_execution_events(request) @@ -6544,28 +5958,25 @@ def test_add_execution_events_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] @pytest.mark.asyncio async def test_add_execution_events_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.AddExecutionEventsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse()) + type(client.transport.add_execution_events), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddExecutionEventsResponse() + ) await client.add_execution_events(request) @@ -6576,29 +5987,24 @@ async def test_add_execution_events_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] def test_add_execution_events_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddExecutionEventsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.add_execution_events( - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) # Establish that the underlying call was made with the expected @@ -6606,45 +6012,43 @@ def test_add_execution_events_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" - assert args[0].events == [event.Event(artifact='artifact_value')] + assert args[0].events == [event.Event(artifact="artifact_value")] def test_add_execution_events_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.add_execution_events( metadata_service.AddExecutionEventsRequest(), - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) @pytest.mark.asyncio async def test_add_execution_events_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_execution_events), - '__call__') as call: + type(client.transport.add_execution_events), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.AddExecutionEventsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.AddExecutionEventsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddExecutionEventsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.add_execution_events( - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) # Establish that the underlying call was made with the expected @@ -6652,31 +6056,31 @@ async def test_add_execution_events_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" - assert args[0].events == [event.Event(artifact='artifact_value')] + assert args[0].events == [event.Event(artifact="artifact_value")] @pytest.mark.asyncio async def test_add_execution_events_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.add_execution_events( metadata_service.AddExecutionEventsRequest(), - execution='execution_value', - events=[event.Event(artifact='artifact_value')], + execution="execution_value", + events=[event.Event(artifact="artifact_value")], ) -def test_query_execution_inputs_and_outputs(transport: str = 'grpc', request_type=metadata_service.QueryExecutionInputsAndOutputsRequest): +def test_query_execution_inputs_and_outputs( + transport: str = "grpc", + request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6685,11 +6089,10 @@ def test_query_execution_inputs_and_outputs(transport: str = 'grpc', request_typ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph( - ) + call.return_value = lineage_subgraph.LineageSubgraph() response = client.query_execution_inputs_and_outputs(request) @@ -6712,25 +6115,27 @@ def test_query_execution_inputs_and_outputs_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: client.query_execution_inputs_and_outputs() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + @pytest.mark.asyncio -async def test_query_execution_inputs_and_outputs_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryExecutionInputsAndOutputsRequest): +async def test_query_execution_inputs_and_outputs_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6739,11 +6144,12 @@ async def test_query_execution_inputs_and_outputs_async(transport: str = 'grpc_a # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) response = await client.query_execution_inputs_and_outputs(request) @@ -6763,19 +6169,17 @@ async def test_query_execution_inputs_and_outputs_async_from_dict(): def test_query_execution_inputs_and_outputs_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryExecutionInputsAndOutputsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: call.return_value = lineage_subgraph.LineageSubgraph() client.query_execution_inputs_and_outputs(request) @@ -6787,28 +6191,25 @@ def test_query_execution_inputs_and_outputs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryExecutionInputsAndOutputsRequest() - request.execution = 'execution/value' + request.execution = "execution/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) await client.query_execution_inputs_and_outputs(request) @@ -6819,70 +6220,61 @@ async def test_query_execution_inputs_and_outputs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'execution=execution/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "execution=execution/value",) in kw["metadata"] def test_query_execution_inputs_and_outputs_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.query_execution_inputs_and_outputs( - execution='execution_value', - ) + client.query_execution_inputs_and_outputs(execution="execution_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" def test_query_execution_inputs_and_outputs_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.query_execution_inputs_and_outputs( metadata_service.QueryExecutionInputsAndOutputsRequest(), - execution='execution_value', + execution="execution_value", ) @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), - '__call__') as call: + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.query_execution_inputs_and_outputs( - execution='execution_value', + execution="execution_value", ) # Establish that the underlying call was made with the expected @@ -6890,28 +6282,27 @@ async def test_query_execution_inputs_and_outputs_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == 'execution_value' + assert args[0].execution == "execution_value" @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.query_execution_inputs_and_outputs( metadata_service.QueryExecutionInputsAndOutputsRequest(), - execution='execution_value', + execution="execution_value", ) -def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_service.CreateMetadataSchemaRequest): +def test_create_metadata_schema( + transport: str = "grpc", request_type=metadata_service.CreateMetadataSchemaRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6920,20 +6311,15 @@ def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_metadata_schema.MetadataSchema( - name='name_value', - - schema_version='schema_version_value', - - schema='schema_value', - + name="name_value", + schema_version="schema_version_value", + schema="schema_value", schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - - description='description_value', - + description="description_value", ) response = client.create_metadata_schema(request) @@ -6948,15 +6334,18 @@ def test_create_metadata_schema(transport: str = 'grpc', request_type=metadata_s assert isinstance(response, gca_metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" def test_create_metadata_schema_from_dict(): @@ -6967,25 +6356,27 @@ def test_create_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: client.create_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.CreateMetadataSchemaRequest() + @pytest.mark.asyncio -async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', request_type=metadata_service.CreateMetadataSchemaRequest): +async def test_create_metadata_schema_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.CreateMetadataSchemaRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -6994,16 +6385,18 @@ async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema( - name='name_value', - schema_version='schema_version_value', - schema='schema_value', - schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_metadata_schema.MetadataSchema( + name="name_value", + schema_version="schema_version_value", + schema="schema_value", + schema_type=gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + description="description_value", + ) + ) response = await client.create_metadata_schema(request) @@ -7016,15 +6409,18 @@ async def test_create_metadata_schema_async(transport: str = 'grpc_asyncio', req # Establish that the response is the type that we expect. assert isinstance(response, gca_metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == gca_metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -7033,19 +6429,17 @@ async def test_create_metadata_schema_async_from_dict(): def test_create_metadata_schema_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataSchemaRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: call.return_value = gca_metadata_schema.MetadataSchema() client.create_metadata_schema(request) @@ -7057,28 +6451,25 @@ def test_create_metadata_schema_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_metadata_schema_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.CreateMetadataSchemaRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema()) + type(client.transport.create_metadata_schema), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_metadata_schema.MetadataSchema() + ) await client.create_metadata_schema(request) @@ -7089,30 +6480,25 @@ async def test_create_metadata_schema_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_metadata_schema_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_metadata_schema.MetadataSchema() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_metadata_schema( - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) # Establish that the underlying call was made with the expected @@ -7120,49 +6506,49 @@ def test_create_metadata_schema_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema(name='name_value') + assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema( + name="name_value" + ) - assert args[0].metadata_schema_id == 'metadata_schema_id_value' + assert args[0].metadata_schema_id == "metadata_schema_id_value" def test_create_metadata_schema_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_metadata_schema( metadata_service.CreateMetadataSchemaRequest(), - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) @pytest.mark.asyncio async def test_create_metadata_schema_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_metadata_schema), - '__call__') as call: + type(client.transport.create_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_metadata_schema.MetadataSchema() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_metadata_schema.MetadataSchema()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_metadata_schema.MetadataSchema() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_metadata_schema( - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) # Establish that the underlying call was made with the expected @@ -7170,34 +6556,35 @@ async def test_create_metadata_schema_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema(name='name_value') + assert args[0].metadata_schema == gca_metadata_schema.MetadataSchema( + name="name_value" + ) - assert args[0].metadata_schema_id == 'metadata_schema_id_value' + assert args[0].metadata_schema_id == "metadata_schema_id_value" @pytest.mark.asyncio async def test_create_metadata_schema_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_metadata_schema( metadata_service.CreateMetadataSchemaRequest(), - parent='parent_value', - metadata_schema=gca_metadata_schema.MetadataSchema(name='name_value'), - metadata_schema_id='metadata_schema_id_value', + parent="parent_value", + metadata_schema=gca_metadata_schema.MetadataSchema(name="name_value"), + metadata_schema_id="metadata_schema_id_value", ) -def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_service.GetMetadataSchemaRequest): +def test_get_metadata_schema( + transport: str = "grpc", request_type=metadata_service.GetMetadataSchemaRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7206,20 +6593,15 @@ def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_schema.MetadataSchema( - name='name_value', - - schema_version='schema_version_value', - - schema='schema_value', - + name="name_value", + schema_version="schema_version_value", + schema="schema_value", schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - - description='description_value', - + description="description_value", ) response = client.get_metadata_schema(request) @@ -7234,15 +6616,18 @@ def test_get_metadata_schema(transport: str = 'grpc', request_type=metadata_serv assert isinstance(response, metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" def test_get_metadata_schema_from_dict(): @@ -7253,25 +6638,27 @@ def test_get_metadata_schema_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: client.get_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.GetMetadataSchemaRequest() + @pytest.mark.asyncio -async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', request_type=metadata_service.GetMetadataSchemaRequest): +async def test_get_metadata_schema_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.GetMetadataSchemaRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7280,16 +6667,18 @@ async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema( - name='name_value', - schema_version='schema_version_value', - schema='schema_value', - schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, - description='description_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_schema.MetadataSchema( + name="name_value", + schema_version="schema_version_value", + schema="schema_value", + schema_type=metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE, + description="description_value", + ) + ) response = await client.get_metadata_schema(request) @@ -7302,15 +6691,18 @@ async def test_get_metadata_schema_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, metadata_schema.MetadataSchema) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.schema_version == 'schema_version_value' + assert response.schema_version == "schema_version_value" - assert response.schema == 'schema_value' + assert response.schema == "schema_value" - assert response.schema_type == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + assert ( + response.schema_type + == metadata_schema.MetadataSchema.MetadataSchemaType.ARTIFACT_TYPE + ) - assert response.description == 'description_value' + assert response.description == "description_value" @pytest.mark.asyncio @@ -7319,19 +6711,17 @@ async def test_get_metadata_schema_async_from_dict(): def test_get_metadata_schema_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataSchemaRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: call.return_value = metadata_schema.MetadataSchema() client.get_metadata_schema(request) @@ -7343,28 +6733,25 @@ def test_get_metadata_schema_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_metadata_schema_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.GetMetadataSchemaRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema()) + type(client.transport.get_metadata_schema), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_schema.MetadataSchema() + ) await client.get_metadata_schema(request) @@ -7375,99 +6762,85 @@ async def test_get_metadata_schema_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_metadata_schema_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_schema.MetadataSchema() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_metadata_schema( - name='name_value', - ) + client.get_metadata_schema(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_metadata_schema_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_metadata_schema( - metadata_service.GetMetadataSchemaRequest(), - name='name_value', + metadata_service.GetMetadataSchemaRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_metadata_schema_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_metadata_schema), - '__call__') as call: + type(client.transport.get_metadata_schema), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_schema.MetadataSchema() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_schema.MetadataSchema()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_schema.MetadataSchema() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_metadata_schema( - name='name_value', - ) + response = await client.get_metadata_schema(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_metadata_schema_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_metadata_schema( - metadata_service.GetMetadataSchemaRequest(), - name='name_value', + metadata_service.GetMetadataSchemaRequest(), name="name_value", ) -def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_service.ListMetadataSchemasRequest): +def test_list_metadata_schemas( + transport: str = "grpc", request_type=metadata_service.ListMetadataSchemasRequest +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7476,12 +6849,11 @@ def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataSchemasResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_metadata_schemas(request) @@ -7496,7 +6868,7 @@ def test_list_metadata_schemas(transport: str = 'grpc', request_type=metadata_se assert isinstance(response, pagers.ListMetadataSchemasPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_metadata_schemas_from_dict(): @@ -7507,25 +6879,27 @@ def test_list_metadata_schemas_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: client.list_metadata_schemas() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.ListMetadataSchemasRequest() + @pytest.mark.asyncio -async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', request_type=metadata_service.ListMetadataSchemasRequest): +async def test_list_metadata_schemas_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.ListMetadataSchemasRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7534,12 +6908,14 @@ async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataSchemasResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_metadata_schemas(request) @@ -7552,7 +6928,7 @@ async def test_list_metadata_schemas_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListMetadataSchemasAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -7561,19 +6937,17 @@ async def test_list_metadata_schemas_async_from_dict(): def test_list_metadata_schemas_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataSchemasRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: call.return_value = metadata_service.ListMetadataSchemasResponse() client.list_metadata_schemas(request) @@ -7585,28 +6959,25 @@ def test_list_metadata_schemas_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_metadata_schemas_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.ListMetadataSchemasRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse()) + type(client.transport.list_metadata_schemas), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataSchemasResponse() + ) await client.list_metadata_schemas(request) @@ -7617,104 +6988,87 @@ async def test_list_metadata_schemas_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_metadata_schemas_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataSchemasResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_metadata_schemas( - parent='parent_value', - ) + client.list_metadata_schemas(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_metadata_schemas_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_metadata_schemas( - metadata_service.ListMetadataSchemasRequest(), - parent='parent_value', + metadata_service.ListMetadataSchemasRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_metadata_schemas_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = metadata_service.ListMetadataSchemasResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(metadata_service.ListMetadataSchemasResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListMetadataSchemasResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_metadata_schemas( - parent='parent_value', - ) + response = await client.list_metadata_schemas(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_metadata_schemas_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_metadata_schemas( - metadata_service.ListMetadataSchemasRequest(), - parent='parent_value', + metadata_service.ListMetadataSchemasRequest(), parent="parent_value", ) def test_list_metadata_schemas_pager(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7723,17 +7077,14 @@ def test_list_metadata_schemas_pager(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7746,9 +7097,7 @@ def test_list_metadata_schemas_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_metadata_schemas(request={}) @@ -7756,18 +7105,16 @@ def test_list_metadata_schemas_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, metadata_schema.MetadataSchema) - for i in results) + assert all(isinstance(i, metadata_schema.MetadataSchema) for i in results) + def test_list_metadata_schemas_pages(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__') as call: + type(client.transport.list_metadata_schemas), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7776,17 +7123,14 @@ def test_list_metadata_schemas_pages(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7797,19 +7141,20 @@ def test_list_metadata_schemas_pages(): RuntimeError, ) pages = list(client.list_metadata_schemas(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_metadata_schemas_async_pager(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_schemas), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7818,17 +7163,14 @@ async def test_list_metadata_schemas_async_pager(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7839,25 +7181,25 @@ async def test_list_metadata_schemas_async_pager(): RuntimeError, ) async_pager = await client.list_metadata_schemas(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, metadata_schema.MetadataSchema) - for i in responses) + assert all(isinstance(i, metadata_schema.MetadataSchema) for i in responses) + @pytest.mark.asyncio async def test_list_metadata_schemas_async_pages(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_metadata_schemas), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_metadata_schemas), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( metadata_service.ListMetadataSchemasResponse( @@ -7866,17 +7208,14 @@ async def test_list_metadata_schemas_async_pages(): metadata_schema.MetadataSchema(), metadata_schema.MetadataSchema(), ], - next_page_token='abc', + next_page_token="abc", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[], - next_page_token='def', + metadata_schemas=[], next_page_token="def", ), metadata_service.ListMetadataSchemasResponse( - metadata_schemas=[ - metadata_schema.MetadataSchema(), - ], - next_page_token='ghi', + metadata_schemas=[metadata_schema.MetadataSchema(),], + next_page_token="ghi", ), metadata_service.ListMetadataSchemasResponse( metadata_schemas=[ @@ -7889,14 +7228,16 @@ async def test_list_metadata_schemas_async_pages(): pages = [] async for page_ in (await client.list_metadata_schemas(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_query_artifact_lineage_subgraph(transport: str = 'grpc', request_type=metadata_service.QueryArtifactLineageSubgraphRequest): +def test_query_artifact_lineage_subgraph( + transport: str = "grpc", + request_type=metadata_service.QueryArtifactLineageSubgraphRequest, +): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7905,11 +7246,10 @@ def test_query_artifact_lineage_subgraph(transport: str = 'grpc', request_type=m # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph( - ) + call.return_value = lineage_subgraph.LineageSubgraph() response = client.query_artifact_lineage_subgraph(request) @@ -7932,25 +7272,27 @@ def test_query_artifact_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: client.query_artifact_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.QueryArtifactLineageSubgraphRequest() + @pytest.mark.asyncio -async def test_query_artifact_lineage_subgraph_async(transport: str = 'grpc_asyncio', request_type=metadata_service.QueryArtifactLineageSubgraphRequest): +async def test_query_artifact_lineage_subgraph_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.QueryArtifactLineageSubgraphRequest, +): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -7959,11 +7301,12 @@ async def test_query_artifact_lineage_subgraph_async(transport: str = 'grpc_asyn # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) response = await client.query_artifact_lineage_subgraph(request) @@ -7983,19 +7326,17 @@ async def test_query_artifact_lineage_subgraph_async_from_dict(): def test_query_artifact_lineage_subgraph_field_headers(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryArtifactLineageSubgraphRequest() - request.artifact = 'artifact/value' + request.artifact = "artifact/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: call.return_value = lineage_subgraph.LineageSubgraph() client.query_artifact_lineage_subgraph(request) @@ -8007,28 +7348,25 @@ def test_query_artifact_lineage_subgraph_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'artifact=artifact/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "artifact=artifact/value",) in kw["metadata"] @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_field_headers_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.QueryArtifactLineageSubgraphRequest() - request.artifact = 'artifact/value' + request.artifact = "artifact/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) await client.query_artifact_lineage_subgraph(request) @@ -8039,70 +7377,61 @@ async def test_query_artifact_lineage_subgraph_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'artifact=artifact/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "artifact=artifact/value",) in kw["metadata"] def test_query_artifact_lineage_subgraph_flattened(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.query_artifact_lineage_subgraph( - artifact='artifact_value', - ) + client.query_artifact_lineage_subgraph(artifact="artifact_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].artifact == 'artifact_value' + assert args[0].artifact == "artifact_value" def test_query_artifact_lineage_subgraph_flattened_error(): - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.query_artifact_lineage_subgraph( metadata_service.QueryArtifactLineageSubgraphRequest(), - artifact='artifact_value', + artifact="artifact_value", ) @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.query_artifact_lineage_subgraph), - '__call__') as call: + type(client.transport.query_artifact_lineage_subgraph), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = lineage_subgraph.LineageSubgraph() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(lineage_subgraph.LineageSubgraph()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.query_artifact_lineage_subgraph( - artifact='artifact_value', + artifact="artifact_value", ) # Establish that the underlying call was made with the expected @@ -8110,21 +7439,19 @@ async def test_query_artifact_lineage_subgraph_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].artifact == 'artifact_value' + assert args[0].artifact == "artifact_value" @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MetadataServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.query_artifact_lineage_subgraph( metadata_service.QueryArtifactLineageSubgraphRequest(), - artifact='artifact_value', + artifact="artifact_value", ) @@ -8135,8 +7462,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -8155,8 +7481,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MetadataServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -8184,13 +7509,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MetadataServiceGrpcTransport, - transports.MetadataServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -8198,13 +7526,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MetadataServiceGrpcTransport, - ) + client = MetadataServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MetadataServiceGrpcTransport,) def test_metadata_service_base_transport_error(): @@ -8212,13 +7535,15 @@ def test_metadata_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MetadataServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_metadata_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MetadataServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -8227,33 +7552,33 @@ def test_metadata_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_metadata_store', - 'get_metadata_store', - 'list_metadata_stores', - 'delete_metadata_store', - 'create_artifact', - 'get_artifact', - 'list_artifacts', - 'update_artifact', - 'create_context', - 'get_context', - 'list_contexts', - 'update_context', - 'delete_context', - 'add_context_artifacts_and_executions', - 'add_context_children', - 'query_context_lineage_subgraph', - 'create_execution', - 'get_execution', - 'list_executions', - 'update_execution', - 'add_execution_events', - 'query_execution_inputs_and_outputs', - 'create_metadata_schema', - 'get_metadata_schema', - 'list_metadata_schemas', - 'query_artifact_lineage_subgraph', - ) + "create_metadata_store", + "get_metadata_store", + "list_metadata_stores", + "delete_metadata_store", + "create_artifact", + "get_artifact", + "list_artifacts", + "update_artifact", + "create_context", + "get_context", + "list_contexts", + "update_context", + "delete_context", + "add_context_artifacts_and_executions", + "add_context_children", + "query_context_lineage_subgraph", + "create_execution", + "get_execution", + "list_executions", + "update_execution", + "add_execution_events", + "query_execution_inputs_and_outputs", + "create_metadata_schema", + "get_metadata_schema", + "list_metadata_schemas", + "query_artifact_lineage_subgraph", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -8266,23 +7591,28 @@ def test_metadata_service_base_transport(): def test_metadata_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MetadataServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_metadata_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.metadata_service.transports.MetadataServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MetadataServiceTransport() @@ -8291,11 +7621,11 @@ def test_metadata_service_base_transport_with_adc(): def test_metadata_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MetadataServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -8303,19 +7633,25 @@ def test_metadata_service_auth_adc(): def test_metadata_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MetadataServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MetadataServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) -def test_metadata_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) +def test_metadata_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -8324,15 +7660,13 @@ def test_metadata_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -8347,38 +7681,40 @@ def test_metadata_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_metadata_service_host_no_port(): client = MetadataServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_metadata_service_host_with_port(): client = MetadataServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_metadata_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MetadataServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -8386,12 +7722,11 @@ def test_metadata_service_grpc_transport_channel(): def test_metadata_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MetadataServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -8400,12 +7735,22 @@ def test_metadata_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) def test_metadata_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -8414,7 +7759,7 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -8430,9 +7775,7 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8446,17 +7789,23 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MetadataServiceGrpcTransport, transports.MetadataServiceGrpcAsyncIOTransport]) -def test_metadata_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MetadataServiceGrpcTransport, + transports.MetadataServiceGrpcAsyncIOTransport, + ], +) +def test_metadata_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -8473,9 +7822,7 @@ def test_metadata_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8488,16 +7835,12 @@ def test_metadata_service_transport_channel_mtls_with_adc( def test_metadata_service_grpc_lro_client(): client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -8505,16 +7848,12 @@ def test_metadata_service_grpc_lro_client(): def test_metadata_service_grpc_lro_async_client(): client = MetadataServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -8526,18 +7865,24 @@ def test_artifact_path(): metadata_store = "whelk" artifact = "octopus" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(project=project, location=location, metadata_store=metadata_store, artifact=artifact, ) - actual = MetadataServiceClient.artifact_path(project, location, metadata_store, artifact) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) + actual = MetadataServiceClient.artifact_path( + project, location, metadata_store, artifact + ) assert expected == actual def test_parse_artifact_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "metadata_store": "cuttlefish", - "artifact": "mussel", - + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "artifact": "mussel", } path = MetadataServiceClient.artifact_path(**expected) @@ -8545,24 +7890,31 @@ def test_parse_artifact_path(): actual = MetadataServiceClient.parse_artifact_path(path) assert expected == actual + def test_context_path(): project = "winkle" location = "nautilus" metadata_store = "scallop" context = "abalone" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(project=project, location=location, metadata_store=metadata_store, context=context, ) - actual = MetadataServiceClient.context_path(project, location, metadata_store, context) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + actual = MetadataServiceClient.context_path( + project, location, metadata_store, context + ) assert expected == actual def test_parse_context_path(): expected = { - "project": "squid", - "location": "clam", - "metadata_store": "whelk", - "context": "octopus", - + "project": "squid", + "location": "clam", + "metadata_store": "whelk", + "context": "octopus", } path = MetadataServiceClient.context_path(**expected) @@ -8570,24 +7922,31 @@ def test_parse_context_path(): actual = MetadataServiceClient.parse_context_path(path) assert expected == actual + def test_execution_path(): project = "oyster" location = "nudibranch" metadata_store = "cuttlefish" execution = "mussel" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(project=project, location=location, metadata_store=metadata_store, execution=execution, ) - actual = MetadataServiceClient.execution_path(project, location, metadata_store, execution) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) + actual = MetadataServiceClient.execution_path( + project, location, metadata_store, execution + ) assert expected == actual def test_parse_execution_path(): expected = { - "project": "winkle", - "location": "nautilus", - "metadata_store": "scallop", - "execution": "abalone", - + "project": "winkle", + "location": "nautilus", + "metadata_store": "scallop", + "execution": "abalone", } path = MetadataServiceClient.execution_path(**expected) @@ -8595,24 +7954,31 @@ def test_parse_execution_path(): actual = MetadataServiceClient.parse_execution_path(path) assert expected == actual + def test_metadata_schema_path(): project = "squid" location = "clam" metadata_store = "whelk" metadata_schema = "octopus" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format(project=project, location=location, metadata_store=metadata_store, metadata_schema=metadata_schema, ) - actual = MetadataServiceClient.metadata_schema_path(project, location, metadata_store, metadata_schema) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/metadataSchemas/{metadata_schema}".format( + project=project, + location=location, + metadata_store=metadata_store, + metadata_schema=metadata_schema, + ) + actual = MetadataServiceClient.metadata_schema_path( + project, location, metadata_store, metadata_schema + ) assert expected == actual def test_parse_metadata_schema_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "metadata_store": "cuttlefish", - "metadata_schema": "mussel", - + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "metadata_schema": "mussel", } path = MetadataServiceClient.metadata_schema_path(**expected) @@ -8620,22 +7986,26 @@ def test_parse_metadata_schema_path(): actual = MetadataServiceClient.parse_metadata_schema_path(path) assert expected == actual + def test_metadata_store_path(): project = "winkle" location = "nautilus" metadata_store = "scallop" - expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format(project=project, location=location, metadata_store=metadata_store, ) - actual = MetadataServiceClient.metadata_store_path(project, location, metadata_store) + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}".format( + project=project, location=location, metadata_store=metadata_store, + ) + actual = MetadataServiceClient.metadata_store_path( + project, location, metadata_store + ) assert expected == actual def test_parse_metadata_store_path(): expected = { - "project": "abalone", - "location": "squid", - "metadata_store": "clam", - + "project": "abalone", + "location": "squid", + "metadata_store": "clam", } path = MetadataServiceClient.metadata_store_path(**expected) @@ -8643,18 +8013,20 @@ def test_parse_metadata_store_path(): actual = MetadataServiceClient.parse_metadata_store_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MetadataServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "octopus", - + "billing_account": "octopus", } path = MetadataServiceClient.common_billing_account_path(**expected) @@ -8662,18 +8034,18 @@ def test_parse_common_billing_account_path(): actual = MetadataServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MetadataServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nudibranch", - + "folder": "nudibranch", } path = MetadataServiceClient.common_folder_path(**expected) @@ -8681,18 +8053,18 @@ def test_parse_common_folder_path(): actual = MetadataServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MetadataServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "mussel", - + "organization": "mussel", } path = MetadataServiceClient.common_organization_path(**expected) @@ -8700,18 +8072,18 @@ def test_parse_common_organization_path(): actual = MetadataServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "winkle" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MetadataServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "nautilus", - + "project": "nautilus", } path = MetadataServiceClient.common_project_path(**expected) @@ -8719,20 +8091,22 @@ def test_parse_common_project_path(): actual = MetadataServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "scallop" location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MetadataServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "squid", - "location": "clam", - + "project": "squid", + "location": "clam", } path = MetadataServiceClient.common_location_path(**expected) @@ -8744,17 +8118,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MetadataServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MetadataServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MetadataServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MetadataServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MetadataServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MetadataServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 51d76cb3c4..f547beb6bf 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceClient, +) from google.cloud.aiplatform_v1beta1.services.migration_service import pagers from google.cloud.aiplatform_v1beta1.services.migration_service import transports from google.cloud.aiplatform_v1beta1.types import migratable_resource @@ -53,7 +57,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -64,36 +72,53 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info) assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ - MigrationServiceClient, - MigrationServiceAsyncClient, -]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -103,7 +128,7 @@ def test_migration_service_client_from_service_account_file(client_class): assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_client_get_transport_class(): @@ -117,29 +142,44 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) -def test_migration_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -155,7 +195,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -171,7 +211,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -191,13 +231,15 @@ def test_migration_service_client_client_options(client_class, transport_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -210,26 +252,62 @@ def test_migration_service_client_client_options(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) @@ -252,10 +330,18 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -276,9 +362,14 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -292,16 +383,23 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -314,16 +412,24 @@ def test_migration_service_client_client_options_scopes(client_class, transport_ client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), -]) -def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -338,10 +444,12 @@ def test_migration_service_client_client_options_credentials_file(client_class, def test_migration_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -354,10 +462,12 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -366,12 +476,11 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_migratable_resources(request) @@ -386,7 +495,7 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_migratable_resources_from_dict(): @@ -397,25 +506,27 @@ def test_search_migratable_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.SearchMigratableResourcesRequest() + @pytest.mark.asyncio -async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): +async def test_search_migratable_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -424,12 +535,14 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.search_migratable_resources(request) @@ -442,7 +555,7 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -451,19 +564,17 @@ async def test_search_migratable_resources_async_from_dict(): def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -475,10 +586,7 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -490,13 +598,15 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) await client.search_migratable_resources(request) @@ -507,49 +617,39 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources( - parent='parent_value', - ) + client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) @@ -561,24 +661,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources( - parent='parent_value', - ) + response = await client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -591,20 +691,17 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -613,17 +710,14 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -636,9 +730,7 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.search_migratable_resources(request={}) @@ -646,18 +738,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in results) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in results + ) + def test_search_migratable_resources_pages(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -666,17 +758,14 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -687,19 +776,20 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -708,17 +798,14 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -729,25 +816,27 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in responses) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in responses + ) + @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -756,17 +845,14 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -779,14 +865,15 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): +def test_batch_migrate_resources( + transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -795,10 +882,10 @@ def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_migrate_resources(request) @@ -820,25 +907,27 @@ def test_batch_migrate_resources_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == migration_service.BatchMigrateResourcesRequest() + @pytest.mark.asyncio -async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): +async def test_batch_migrate_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.BatchMigrateResourcesRequest, +): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -847,11 +936,11 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_migrate_resources(request) @@ -872,20 +961,18 @@ async def test_batch_migrate_resources_async_from_dict(): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_migrate_resources(request) @@ -896,10 +983,7 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -911,13 +995,15 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_migrate_resources(request) @@ -928,29 +1014,30 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -958,23 +1045,33 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -986,19 +1083,25 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -1006,9 +1109,15 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] @pytest.mark.asyncio @@ -1022,8 +1131,14 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -1034,8 +1149,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1054,8 +1168,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1083,13 +1196,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1097,13 +1213,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MigrationServiceGrpcTransport, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) def test_migration_service_base_transport_error(): @@ -1111,13 +1222,15 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1126,9 +1239,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'search_migratable_resources', - 'batch_migrate_resources', - ) + "search_migratable_resources", + "batch_migrate_resources", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1141,23 +1254,28 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1166,11 +1284,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1178,19 +1296,25 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MigrationServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1199,15 +1323,13 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1222,38 +1344,40 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_migration_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1261,12 +1385,11 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1275,12 +1398,22 @@ def test_migration_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1289,7 +1422,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1305,9 +1438,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1321,17 +1452,23 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1348,9 +1485,7 @@ def test_migration_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1363,16 +1498,12 @@ def test_migration_service_transport_channel_mtls_with_adc( def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1380,16 +1511,12 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1400,17 +1527,20 @@ def test_annotated_dataset_path(): dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) - actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) + actual = MigrationServiceClient.annotated_dataset_path( + project, dataset, annotated_dataset + ) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", - + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1418,22 +1548,24 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1441,20 +1573,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "squid" dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", - + "project": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1462,22 +1596,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "oyster" location = "nudibranch" dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", - "dataset": "nautilus", - + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1485,22 +1621,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - + "project": "clam", + "location": "whelk", + "model": "octopus", } path = MigrationServiceClient.model_path(**expected) @@ -1508,22 +1646,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", - + "project": "mussel", + "location": "winkle", + "model": "nautilus", } path = MigrationServiceClient.model_path(**expected) @@ -1531,22 +1671,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + expected = "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", - + "project": "clam", + "model": "whelk", + "version": "octopus", } path = MigrationServiceClient.version_path(**expected) @@ -1554,18 +1696,20 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", - + "billing_account": "nudibranch", } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1573,18 +1717,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", - + "folder": "mussel", } path = MigrationServiceClient.common_folder_path(**expected) @@ -1592,18 +1736,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", - + "organization": "nautilus", } path = MigrationServiceClient.common_organization_path(**expected) @@ -1611,18 +1755,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", - + "project": "abalone", } path = MigrationServiceClient.common_project_path(**expected) @@ -1630,20 +1774,22 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", - + "project": "whelk", + "location": "octopus", } path = MigrationServiceClient.common_location_path(**expected) @@ -1655,17 +1801,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) From 1ad9ee859addaa44129c094551fc755626d94912 Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Tue, 20 Apr 2021 09:53:27 -0700 Subject: [PATCH 14/36] feat: implement get_experiment and get_pipeline methods for Metadata service (#326) --- google/cloud/aiplatform/__init__.py | 4 +- google/cloud/aiplatform/metadata/artifact.py | 30 ++ google/cloud/aiplatform/metadata/constants.py | 2 + google/cloud/aiplatform/metadata/context.py | 59 +++- google/cloud/aiplatform/metadata/execution.py | 95 ++++-- google/cloud/aiplatform/metadata/metadata.py | 204 ++++++++++++- google/cloud/aiplatform/metadata/resource.py | 120 ++++++-- setup.py | 1 + tests/unit/aiplatform/test_metadata.py | 253 +++++++++++++++- .../aiplatform/test_metadata_resources.py | 270 +++++++++++++++--- 10 files changed, 927 insertions(+), 111 deletions(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 6d1a197766..49a301db1e 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -51,7 +51,7 @@ log_metrics = metadata.metadata_service.log_metrics get_experiment = metadata.metadata_service.get_experiment get_pipeline = metadata.metadata_service.get_pipeline -set_run = metadata.metadata_service.set_run +start_run = metadata.metadata_service.start_run __all__ = ( @@ -62,7 +62,7 @@ "log_metrics", "get_experiment", "get_pipeline", - "set_run", + "start_run", "AutoMLImageTrainingJob", "AutoMLTabularTrainingJob", "AutoMLTextTrainingJob", diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py index eb835cafd2..98eefacc5f 100644 --- a/google/cloud/aiplatform/metadata/artifact.py +++ b/google/cloud/aiplatform/metadata/artifact.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata.resource import _Resource +from google.cloud.aiplatform_v1beta1 import ListArtifactsRequest from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact @@ -57,4 +58,33 @@ def _create_resource( def _update_resource( cls, client: utils.MetadataClientWithOverride, resource: proto.Message, ) -> proto.Message: + """Update Artifacts with given input. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + resource (proto.Message): + Required. The proto.Message which contains the update information for the resource. + """ + return client.update_artifact(artifact=resource) + + @classmethod + def _list_resources( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + filter: Optional[str] = None, + ): + """List artifacts in the parent path that matches the filter. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + parent (str): + Required. The path where Artifacts are stored. + filter (str): + Optional. filter string to restrict the list result + """ + list_request = ListArtifactsRequest(parent=parent, filter=filter,) + return client.list_artifacts(request=list_request) diff --git a/google/cloud/aiplatform/metadata/constants.py b/google/cloud/aiplatform/metadata/constants.py index bee9677cd5..7db87cc222 100644 --- a/google/cloud/aiplatform/metadata/constants.py +++ b/google/cloud/aiplatform/metadata/constants.py @@ -17,6 +17,7 @@ SYSTEM_RUN = "system.Run" SYSTEM_EXPERIMENT = "system.Experiment" +SYSTEM_PIPELINE = "system.Pipeline" SYSTEM_METRICS = "system.Metrics" _DEFAULT_SCHEMA_VERSION = "0.0.1" @@ -24,6 +25,7 @@ SCHEMA_VERSIONS = { SYSTEM_RUN: _DEFAULT_SCHEMA_VERSION, SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION, + SYSTEM_PIPELINE: _DEFAULT_SCHEMA_VERSION, SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION, } diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py index 76e6283d51..cb3340499b 100644 --- a/google/cloud/aiplatform/metadata/context.py +++ b/google/cloud/aiplatform/metadata/context.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata.resource import _Resource +from google.cloud.aiplatform_v1beta1 import ListContextsRequest from google.cloud.aiplatform_v1beta1.types import context as gca_context @@ -30,6 +31,25 @@ class _Context(_Resource): _resource_noun = "contexts" _getter_method = "get_context" + def add_artifacts_and_executions( + self, + artifact_resource_names: Optional[Sequence[str]] = None, + execution_resource_names: Optional[Sequence[str]] = None, + ): + """Associate Executions and attribute Artifacts to a given Context. + + Args: + artifact_resource_names (Sequence[str]): + Optional. The full resource name of Artifacts to attribute to the Context. + execution_resource_names (Sequence[str]): + Optional. The full resource name of Executions to associate with the Context. + """ + self.api_client.add_context_artifacts_and_executions( + context=self.resource_name, + artifacts=artifact_resource_names, + executions=execution_resource_names, + ) + @classmethod def _create_resource( cls, @@ -57,23 +77,34 @@ def _create_resource( def _update_resource( cls, client: utils.MetadataClientWithOverride, resource: proto.Message, ) -> proto.Message: + """Update Contexts with given input. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + resource (proto.Message): + Required. The proto.Message which contains the update information for the resource. + """ + return client.update_context(context=resource) - def add_artifacts_and_executions( - self, - artifact_resource_names: Optional[Sequence[str]] = None, - execution_resource_names: Optional[Sequence[str]] = None, + @classmethod + def _list_resources( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + filter: Optional[str] = None, ): - """Creates a new Metadata resource. + """List Contexts in the parent path that matches the filter. Args: - artifact_resource_names (Sequence[str]): - Optional. The full resource name of Artifacts to attribute to the Context. - execution_resource_names (Sequence[str]): - Optional. The full resource name of Executions to associate with the Context. + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + parent (str): + Required. The path where Contexts are stored. + filter (str): + Optional. filter string to restrict the list result """ - self.api_client.add_context_artifacts_and_executions( - context=self.resource_name, - artifacts=artifact_resource_names, - executions=execution_resource_names, - ) + + list_request = ListContextsRequest(parent=parent, filter=filter,) + return client.list_contexts(request=list_request) diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py index 6a173e07dd..39fc7a74b3 100644 --- a/google/cloud/aiplatform/metadata/execution.py +++ b/google/cloud/aiplatform/metadata/execution.py @@ -15,14 +15,17 @@ # limitations under the License. # -from typing import Optional, Dict +from typing import Optional, Dict, Sequence import proto +from google.api_core import exceptions from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.artifact import _Artifact from google.cloud.aiplatform.metadata.resource import _Resource from google.cloud.aiplatform_v1beta1 import Event from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types.metadata_service import ListExecutionsRequest class _Execution(_Resource): @@ -31,6 +34,51 @@ class _Execution(_Resource): _resource_noun = "executions" _getter_method = "get_execution" + def add_artifact( + self, artifact_resource_name: str, input: bool, + ): + """Connect Artifact to a given Execution. + + Args: + artifact_resource_name (str): + Required. The full resource name of the Artifact to connect to the Execution through an Event. + input (bool) + Required. Whether Artifact is an input event to the Execution or not. + """ + + event = Event( + artifact=artifact_resource_name, + type_=Event.Type.INPUT if input else Event.Type.OUTPUT, + ) + + self.api_client.add_execution_events( + execution=self.resource_name, events=[event], + ) + + def query_input_and_output_artifacts(self) -> Sequence[_Artifact]: + """query the input and output artifacts connected to the execution. + + Returns: + A Sequence of _Artifacts + """ + + try: + artifacts = self.api_client.query_execution_inputs_and_outputs( + execution=self.resource_name + ).artifacts + except exceptions.NotFound: + return [] + + return [ + _Artifact( + resource=artifact, + project=self.project, + location=self.location, + credentials=self.credentials, + ) + for artifact in artifacts + ] + @classmethod def _create_resource( cls, @@ -54,29 +102,38 @@ def _create_resource( parent=parent, execution=gapic_execution, execution_id=resource_id, ) + @classmethod + def _list_resources( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + filter: Optional[str] = None, + ): + """List Executions in the parent path that matches the filter. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + parent (str): + Required. The path where Executions are stored. + filter (str): + Optional. filter string to restrict the list result + """ + + list_request = ListExecutionsRequest(parent=parent, filter=filter,) + return client.list_executions(request=list_request) + @classmethod def _update_resource( cls, client: utils.MetadataClientWithOverride, resource: proto.Message, ) -> proto.Message: - return client.update_execution(execution=resource) - - def add_artifact( - self, artifact_resource_name: str, input: bool, - ): - """Creates a new Metadata resource. + """Update Executions with given input. Args: - artifact_resource_name (str): - Required. The full resource name of the Artifact to connect to the Execution through an Event. - input (bool) - Required. Whether Artifact is an input event to the Execution or not. + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + resource (proto.Message): + Required. The proto.Message which contains the update information for the resource. """ - event = Event( - artifact=artifact_resource_name, - type_=Event.Type.INPUT if input else Event.Type.OUTPUT, - ) - - self.api_client.add_execution_events( - execution=self.resource_name, events=[event], - ) + return client.update_execution(execution=resource) diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index c1b76d04d0..c0314e921f 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Dict, Union +from typing import Dict, Union, Optional from google.cloud.aiplatform.metadata import constants from google.cloud.aiplatform.metadata.artifact import _Artifact @@ -43,11 +43,18 @@ def set_experiment(self, experiment: str): ) self._experiment = context - def set_run(self, run: str): + def start_run(self, run: str): + """Setup a run to current session. + + Args: + run (str): + Required. Name of the run to assign current session with. + """ + if not self._experiment: raise ValueError( "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') " - "before trying to set_run. " + "before trying to start_run. " ) run_execution_id = f"{self._experiment.name}-{run}" run_execution = _Execution.get_or_create( @@ -75,6 +82,13 @@ def set_run(self, run: str): self._metrics = metrics_artifact def log_params(self, params: Dict[str, Union[float, int, str]]): + """Log single or multiple parameters with specified key and value pairs. + + Args: + params (Dict): + Required. Parameter key/value pairs. + """ + self._validate_experiment_and_run(method_name="log_params") # query the latest run execution resource before logging. execution = _Execution.get_or_create( @@ -85,6 +99,13 @@ def log_params(self, params: Dict[str, Union[float, int, str]]): execution.update(metadata=params) def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): + """Log single or multiple Metrics with specified key and value pairs. + + Args: + metrics (Dict): + Required. Metrics key/value pairs. + """ + self._validate_experiment_and_run(method_name="log_metrics") # query the latest metrics artifact resource before logging. artifact = _Artifact.get_or_create( @@ -94,11 +115,80 @@ def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): ) artifact.update(metadata=metrics) - def get_experiment(self, experiment: str): - raise NotImplementedError("get_experiment not implemented") + def get_experiment( + self, experiment: Optional[str] = None + ) -> "pd.DataFrame": # noqa: F821 + """Returns a Pandas DataFrame of the parameters and metrics associated with one experiment. + + Example: + + aiplatform.init(experiment='exp-1') + aiplatform.start_run(run='run-1') + aiplatform.log_params({'learning_rate': 0.1}) + aiplatform.log_metrics({'accuracy': 0.9}) + + aiplatform.start_run(run='run-2') + aiplatform.log_params({'learning_rate': 0.2}) + aiplatform.log_metrics({'accuracy': 0.95}) + + Will result in the following DataFrame + ___________________________________________________________________________ + | experiment_name | run_name | param.learning_rate | metric.accuracy | + --------------------------------------------------------------------------- + | exp-1 | run-1 | 0.1 | 0.9 | + | exp-1 | run-2 | 0.2 | 0.95 | + --------------------------------------------------------------------------- + + Args: + experiment (str): + Name of the Experiment to filter results. If not set, return results of current active experiment. + + Returns: + Pandas Dataframe of Experiment with metrics and parameters. + + Raise: + NotFound exception if experiment does not exist. + ValueError if given experiment is not associated with a wrong schema. + """ + + if not experiment: + experiment = self._experiment + + source = "experiment" + experiment_resource_name = self._get_experiment_or_pipeline_resource_name( + name=experiment, source=source, expected_schema=constants.SYSTEM_EXPERIMENT, + ) + + return self._query_runs_to_data_frame( + context_id=experiment, + context_resource_name=experiment_resource_name, + source=source, + ) + + def get_pipeline(self, pipeline: str) -> "pd.DataFrame": # noqa: F821 + """Returns a Pandas DataFrame of the parameters and metrics associated with one pipeline. + + Args: + pipeline: Name of the Pipeline to filter results. + + Returns: + Pandas Dataframe of Pipeline with metrics and parameters. + + Raise: + NotFound exception if experiment does not exist. + ValueError if given experiment is not associated with a wrong schema. + """ + + source = "pipeline" + pipeline_resource_name = self._get_experiment_or_pipeline_resource_name( + name=pipeline, source=source, expected_schema=constants.SYSTEM_PIPELINE + ) - def get_pipeline(self, pipeline: str): - raise NotImplementedError("get_pipeline not implemented") + return self._query_runs_to_data_frame( + context_id=pipeline, + context_resource_name=pipeline_resource_name, + source=source, + ) def _validate_experiment_and_run(self, method_name: str): if not self._experiment: @@ -108,8 +198,106 @@ def _validate_experiment_and_run(self, method_name: str): ) if not self._run: raise ValueError( - f"No run set. Make sure to call aiplatform.set_run('my-run') before trying to {method_name}. " + f"No run set. Make sure to call aiplatform.start_run('my-run') before trying to {method_name}. " + ) + + @staticmethod + def _get_experiment_or_pipeline_resource_name( + name: str, source: str, expected_schema: str + ) -> str: + """Get the full resource name of the Context representing an Experiment or Pipeline. + + Args: + name (str): + Name of the Experiment or Pipeline. + source (str): + Identify whether the this is an Experiment or a Pipeline. + expected_schema (str): + expected_schema identifies the expected schema used for Experiment or Pipeline. + + Returns: + The full resource name of the Experiment or Pipeline Context. + + Raise: + NotFound exception if experiment or pipeline does not exist. + """ + + context = _Context(resource_name=name) + + if context.schema_title != expected_schema: + raise ValueError( + f"Please provide a valid {source} name. {name} is not a {source}." + ) + return context.resource_name + + def _query_runs_to_data_frame( + self, context_id: str, context_resource_name: str, source: str + ) -> "pd.DataFrame": # noqa: F821 + """Get metrics and parameters associated with a given Context into a Dataframe. + + Args: + context_id (str): + Name of the Experiment or Pipeline. + context_resource_name (str): + Full resource name of the Context associated with an Experiment or Pipeline. + source (str): + Identify whether the this is an Experiment or a Pipeline. + + Returns: + The full resource name of the Experiment or Pipeline Context. + """ + + filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{context_resource_name}")' + run_executions = _Execution.list(filter=filter) + + context_summary = [] + for run_execution in run_executions: + run_dict = { + f"{source}_name": context_id, + "run_name": run_execution.display_name, + } + run_dict.update( + self._execution_to_column_named_metadata( + "param", run_execution.metadata + ) + ) + + for metric_artifact in run_execution.query_input_and_output_artifacts(): + run_dict.update( + self._execution_to_column_named_metadata( + "metric", metric_artifact.metadata + ) + ) + + context_summary.append(run_dict) + + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to get dataframe as the return format. " + 'Please install the SDK using "pip install python-aiplatform[full]"' ) + return pd.DataFrame(context_summary) + + @staticmethod + def _execution_to_column_named_metadata( + metadata_type: str, metadata: Dict, + ) -> Dict[str, Union[int, float, str]]: + """Returns a dict of the Execution/Artifact metadata with column names. + + Args: + metadata_type: The type of this execution properties (param, metric). + metadata: Either an Execution or Artifact metadata field. + + Returns: + Dict of custom properties with keys mapped to column names + """ + + return { + ".".join([metadata_type, key]): value for key, value in metadata.items() + } + metadata_service = _MetadataService() diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py index 022e51e7e8..03266bafe3 100644 --- a/google/cloud/aiplatform/metadata/resource.py +++ b/google/cloud/aiplatform/metadata/resource.py @@ -19,14 +19,18 @@ import logging import re from copy import deepcopy -from typing import Optional, Dict +from typing import Optional, Dict, Union, Sequence import proto from google.api_core import exceptions from google.auth import credentials as auth_credentials +from google.protobuf import json_format from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact +from google.cloud.aiplatform_v1beta1 import Context as GapicContext +from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution class _Resource(base.AiPlatformResourceNounWithFutureManager, abc.ABC): @@ -38,8 +42,9 @@ class _Resource(base.AiPlatformResourceNounWithFutureManager, abc.ABC): def __init__( self, - resource_name: str, - metadata_store_id: Optional[str] = "default", + resource_name: Optional[str] = None, + resource: Optional[Union[GapicContext, GapicArtifact, GapicExecution]] = None, + metadata_store_id: str = "default", project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, @@ -50,7 +55,11 @@ def __init__( resource_name (str): A fully-qualified resource name or ID Example: "projects/123/locations/us-central1/metadataStores/default//my-resource". - or "my-resource" when project and location are initialized or passed. + or "my-resource" when project and location are initialized or passed. if ``resource`` is provided, this + should not be set. + resource (Union[GapicContext, GapicArtifact, GapicExecution]): + The proto.Message that contains the full information of the resource. If both set, this field overrides + ``resource_name`` field. metadata_store_id (str): MetadataStore to retrieve resource from. If not set, metadata_store_id is set to "default". If resource_name is a fully-qualified resource, its metadata_store_id overrides this one. @@ -69,23 +78,32 @@ def __init__( project=project, location=location, credentials=credentials, ) - # If we receive a full resource name, we extract the metadata_store_id and use that - if "/" in resource_name: - metadata_store_id = _Resource._extract_metadata_store_id( - resource_name, self._resource_noun - ) + if resource: + self._gca_resource = resource + return - full_resource_name = utils.full_resource_name( - resource_name=resource_name, - resource_noun=f"metadataStores/{metadata_store_id}/{self._resource_noun}", - project=self.project, - location=self.location, - ) + full_resource_name = resource_name + # Construct the full_resource_name if input resource_name is the resource_id + if "/" not in resource_name: + full_resource_name = utils.full_resource_name( + resource_name=resource_name, + resource_noun=f"metadataStores/{metadata_store_id}/{self._resource_noun}", + project=self.project, + location=self.location, + ) self._gca_resource = getattr(self.api_client, self._getter_method)( name=full_resource_name ) + @property + def metadata(self) -> Dict: + return json_format.MessageToDict(self._gca_resource._pb)["metadata"] + + @property + def schema_title(self) -> str: + return self._gca_resource.schema_title + @classmethod def get_or_create( cls, @@ -95,7 +113,7 @@ def get_or_create( schema_version: Optional[str] = None, description: Optional[str] = None, metadata: Optional[Dict] = None, - metadata_store_id: Optional[str] = "default", + metadata_store_id: str = "default", project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, @@ -189,6 +207,70 @@ def update( ) self._gca_resource = update_gca_resource + @classmethod + def list( + cls, + filter: Optional[str] = None, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Sequence["_Resource"]: + """List Metadata resources that match the list filter in target metadataStore. + + Args: + filter (str): + Optional. A query to filter available resources for + matching results. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resources (sequence[_Resource]): + a list of managed Metadata resource. + + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + + parent = ( + initializer.global_config.common_location_path( + project=project, location=location + ) + + f"/metadataStores/{metadata_store_id}" + ) + + try: + resources = cls._list_resources( + client=api_client, parent=parent, filter=filter, + ) + except exceptions.NotFound: + logging.info( + f"No matching resources in metadataStore: {metadata_store_id} with filter: {filter}" + ) + return [] + + return [ + cls( + resource=resource, + project=project, + location=location, + credentials=credentials, + ) + for resource in resources + ] + @classmethod def _create( cls, @@ -251,7 +333,7 @@ def _create( ) try: - cls._create_resource( + resource = cls._create_resource( client=api_client, parent=parent, resource_id=resource_id, @@ -263,10 +345,10 @@ def _create( ) except exceptions.AlreadyExists: logging.info(f"Resource '{resource_id}' already exist") + return return cls( - resource_name=f"{parent}/{cls._resource_noun}/{resource_id}", - metadata_store_id=metadata_store_id, + resource=resource, project=project, location=location, credentials=credentials, diff --git a/setup.py b/setup.py index b89d2a6417..84bb3c75d9 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ "google-cloud-storage >= 1.26.0, < 2.0.0dev", "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", ), + extras_require={"full": ["pandas>=1.0.0"]}, python_requires=">=3.6", scripts=[], classifiers=[ diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index cae78d85ee..a3ad97f2c8 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -16,9 +16,10 @@ # from importlib import reload -from unittest.mock import patch +from unittest.mock import patch, call import pytest +from google.api_core import exceptions from google.cloud import aiplatform from google.cloud.aiplatform import initializer @@ -27,6 +28,8 @@ from google.cloud.aiplatform_v1beta1 import ( AddContextArtifactsAndExecutionsResponse, Event, + LineageSubgraph, + ListExecutionsRequest, ) from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact from google.cloud.aiplatform_v1beta1 import Context as GapicContext @@ -38,13 +41,16 @@ from google.cloud.aiplatform_v1beta1 import MetadataStore as GapicMetadataStore # project + _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" _TEST_PARENT = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" ) _TEST_EXPERIMENT = "test-experiment" -_TEST_RUN = "run" +_TEST_PIPELINE = _TEST_EXPERIMENT +_TEST_RUN = "run-1" +_TEST_OTHER_RUN = "run-2" # resource attributes _TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True} @@ -61,16 +67,26 @@ # execution _TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}" _TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}" +_TEST_OTHER_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_OTHER_RUN}" +_TEST_OTHER_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_OTHER_EXECUTION_ID}" # artifact _TEST_ARTIFACT_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}-metrics" _TEST_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_ARTIFACT_ID}" +_TEST_OTHER_ARTIFACT_ID = f"{_TEST_EXPERIMENT}-{_TEST_OTHER_RUN}-metrics" +_TEST_OTHER_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_OTHER_ARTIFACT_ID}" # parameters -_TEST_PARAMS = {"learning_rate": 0.01, "dropout": 0.2} +_TEST_PARAM_KEY_1 = "learning_rate" +_TEST_PARAM_KEY_2 = "dropout" +_TEST_PARAMS = {_TEST_PARAM_KEY_1: 0.01, _TEST_PARAM_KEY_2: 0.2} +_TEST_OTHER_PARAMS = {_TEST_PARAM_KEY_1: 0.02, _TEST_PARAM_KEY_2: 0.3} # metrics -_TEST_METRICS = {"rmse": 222, "accuracy": 1} +_TEST_METRIC_KEY_1 = "rmse" +_TEST_METRIC_KEY_2 = "accuracy" +_TEST_METRICS = {_TEST_METRIC_KEY_1: 222, _TEST_METRIC_KEY_2: 1} +_TEST_OTHER_METRICS = {_TEST_METRIC_KEY_2: 0.9} @pytest.fixture @@ -97,6 +113,30 @@ def get_context_mock(): yield get_context_mock +@pytest.fixture +def get_pipeline_context_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_pipeline_context_mock: + get_pipeline_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + schema_title=constants.SYSTEM_PIPELINE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_PIPELINE], + metadata=constants.EXPERIMENT_METADATA, + ) + yield get_pipeline_context_mock + + +@pytest.fixture +def get_context_not_found_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_context_not_found_mock: + get_context_not_found_mock.side_effect = exceptions.NotFound("test: not found") + yield get_context_not_found_mock + + @pytest.fixture def add_context_artifacts_and_executions_mock(): with patch.object( @@ -144,6 +184,64 @@ def add_execution_events_mock(): yield add_execution_events_mock +@pytest.fixture +def list_executions_mock(): + with patch.object(MetadataServiceClient, "list_executions") as list_executions_mock: + list_executions_mock.return_value = [ + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ), + GapicExecution( + name=_TEST_OTHER_EXECUTION_NAME, + display_name=_TEST_OTHER_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_OTHER_PARAMS, + ), + ] + yield list_executions_mock + + +@pytest.fixture +def query_execution_inputs_and_outputs_mock(): + with patch.object( + MetadataServiceClient, "query_execution_inputs_and_outputs" + ) as query_execution_inputs_and_outputs_mock: + query_execution_inputs_and_outputs_mock.side_effect = [ + LineageSubgraph( + artifacts=[ + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[ + constants.SYSTEM_METRICS + ], + metadata=_TEST_METRICS, + ), + ], + ), + LineageSubgraph( + artifacts=[ + GapicArtifact( + name=_TEST_OTHER_ARTIFACT_NAME, + display_name=_TEST_OTHER_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[ + constants.SYSTEM_METRICS + ], + metadata=_TEST_OTHER_METRICS, + ), + ], + ), + ] + yield query_execution_inputs_and_outputs_mock + + @pytest.fixture def get_artifact_mock(): with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: @@ -169,6 +267,20 @@ def update_artifact_mock(): yield update_artifact_mock +def _assert_frame_equal_with_sorted_columns(dataframe_1, dataframe_2): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to test the get_experiment/pipeline method. " + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + + pd.testing.assert_frame_equal( + dataframe_1.sort_index(axis=1), dataframe_2.sort_index(axis=1), check_names=True + ) + + class TestMetadata: def setup_method(self): reload(initializer) @@ -190,7 +302,7 @@ def test_init_experiment_with_existing_metadataStore_and_context( @pytest.mark.usefixtures("get_metadata_store_mock") @pytest.mark.usefixtures("get_context_mock") - def test_set_run_with_existing_execution_and_artifact( + def test_start_run_with_existing_execution_and_artifact( self, get_execution_mock, add_context_artifacts_and_executions_mock, @@ -200,7 +312,7 @@ def test_set_run_with_existing_execution_and_artifact( aiplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT ) - aiplatform.set_run(_TEST_RUN) + aiplatform.start_run(_TEST_RUN) get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) add_context_artifacts_and_executions_mock.assert_called_once_with( @@ -226,7 +338,7 @@ def test_log_params( aiplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT ) - aiplatform.set_run(_TEST_RUN) + aiplatform.start_run(_TEST_RUN) aiplatform.log_params(_TEST_PARAMS) updated_execution = GapicExecution( @@ -251,7 +363,7 @@ def test_log_metrics( aiplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT ) - aiplatform.set_run(_TEST_RUN) + aiplatform.start_run(_TEST_RUN) aiplatform.log_metrics(_TEST_METRICS) updated_artifact = GapicArtifact( @@ -263,3 +375,128 @@ def test_log_metrics( ) update_artifact_mock.assert_called_once_with(artifact=updated_artifact) + + # TODO: remove skip once koroko test would install extra required packages. + @pytest.mark.skip( + reason="Temporarily skip this test as extra required package are not installed in current setup" + ) + @pytest.mark.usefixtures("get_context_mock") + def test_get_experiment( + self, list_executions_mock, query_execution_inputs_and_outputs_mock + ): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to test the get_experiment method. " + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + experiment_df = aiplatform.get_experiment(_TEST_EXPERIMENT) + + expected_filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{_TEST_CONTEXT_NAME}")' + list_executions_mock.assert_called_once_with( + request=ListExecutionsRequest(parent=_TEST_PARENT, filter=expected_filter,) + ) + query_execution_inputs_and_outputs_mock.assert_has_calls( + [ + call(execution=_TEST_EXECUTION_NAME), + call(execution=_TEST_OTHER_EXECUTION_NAME), + ] + ) + experiment_df_truth = pd.DataFrame( + [ + { + "experiment_name": _TEST_EXPERIMENT, + "run_name": _TEST_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.01, + "param.%s" % _TEST_PARAM_KEY_2: 0.2, + "metric.%s" % _TEST_METRIC_KEY_1: 222, + "metric.%s" % _TEST_METRIC_KEY_2: 1, + }, + { + "experiment_name": _TEST_EXPERIMENT, + "run_name": _TEST_OTHER_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.02, + "param.%s" % _TEST_PARAM_KEY_2: 0.3, + "metric.%s" % _TEST_METRIC_KEY_2: 0.9, + }, + ] + ) + + _assert_frame_equal_with_sorted_columns(experiment_df, experiment_df_truth) + + @pytest.mark.usefixtures("get_context_not_found_mock") + def test_get_experiment_not_exist(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(exceptions.NotFound): + aiplatform.get_experiment(_TEST_EXPERIMENT) + + @pytest.mark.usefixtures("get_pipeline_context_mock") + def test_get_experiment_wrong_schema(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(ValueError): + aiplatform.get_experiment(_TEST_EXPERIMENT) + + @pytest.mark.skip( + reason="Temporarily skip this test as extra required package are not installed in current setup" + ) + @pytest.mark.usefixtures("get_pipeline_context_mock") + def test_get_pipeline( + self, list_executions_mock, query_execution_inputs_and_outputs_mock + ): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to test the get_pipeline method. " + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + pipeline_df = aiplatform.get_pipeline(_TEST_PIPELINE) + + expected_filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{_TEST_CONTEXT_NAME}")' + list_executions_mock.assert_called_once_with( + request=ListExecutionsRequest(parent=_TEST_PARENT, filter=expected_filter,) + ) + query_execution_inputs_and_outputs_mock.assert_has_calls( + [ + call(execution=_TEST_EXECUTION_NAME), + call(execution=_TEST_OTHER_EXECUTION_NAME), + ] + ) + pipeline_df_truth = pd.DataFrame( + [ + { + "pipeline_name": _TEST_PIPELINE, + "run_name": _TEST_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.01, + "param.%s" % _TEST_PARAM_KEY_2: 0.2, + "metric.%s" % _TEST_METRIC_KEY_1: 222, + "metric.%s" % _TEST_METRIC_KEY_2: 1, + }, + { + "pipeline_name": _TEST_PIPELINE, + "run_name": _TEST_OTHER_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.02, + "param.%s" % _TEST_PARAM_KEY_2: 0.3, + "metric.%s" % _TEST_METRIC_KEY_2: 0.9, + }, + ] + ) + + _assert_frame_equal_with_sorted_columns(pipeline_df, pipeline_df_truth) + + @pytest.mark.usefixtures("get_context_not_found_mock") + def test_get_pipeline_not_exist(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(exceptions.NotFound): + aiplatform.get_pipeline(_TEST_PIPELINE) + + @pytest.mark.usefixtures("get_context_mock") + def test_get_pipeline_wrong_schema(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(ValueError): + aiplatform.get_pipeline(_TEST_PIPELINE) diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py index a46860f37c..19258aef3c 100644 --- a/tests/unit/aiplatform/test_metadata_resources.py +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -16,7 +16,7 @@ # from importlib import reload -from unittest.mock import patch, call +from unittest.mock import patch import pytest from google.api_core import exceptions @@ -30,10 +30,14 @@ from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact from google.cloud.aiplatform_v1beta1 import Context as GapicContext from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution +from google.cloud.aiplatform_v1beta1 import LineageSubgraph from google.cloud.aiplatform_v1beta1 import ( MetadataServiceClient, AddExecutionEventsResponse, Event, + ListExecutionsRequest, + ListArtifactsRequest, + ListContextsRequest, ) # project @@ -87,17 +91,9 @@ def get_context_for_get_or_create_mock(): with patch.object( MetadataServiceClient, "get_context" ) as get_context_for_get_or_create_mock: - get_context_for_get_or_create_mock.side_effect = [ - exceptions.NotFound("test: Context Not Found"), - GapicContext( - name=_TEST_CONTEXT_NAME, - display_name=_TEST_DISPLAY_NAME, - schema_title=_TEST_SCHEMA_TITLE, - schema_version=_TEST_SCHEMA_VERSION, - description=_TEST_DESCRIPTION, - metadata=_TEST_METADATA, - ), - ] + get_context_for_get_or_create_mock.side_effect = exceptions.NotFound( + "test: Context Not Found" + ) yield get_context_for_get_or_create_mock @@ -115,6 +111,30 @@ def create_context_mock(): yield create_context_mock +@pytest.fixture +def list_contexts_mock(): + with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock: + list_contexts_mock.return_value = [ + GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield list_contexts_mock + + @pytest.fixture def update_context_mock(): with patch.object(MetadataServiceClient, "update_context") as update_context_mock: @@ -159,17 +179,9 @@ def get_execution_for_get_or_create_mock(): with patch.object( MetadataServiceClient, "get_execution" ) as get_execution_for_get_or_create_mock: - get_execution_for_get_or_create_mock.side_effect = [ - exceptions.NotFound("test: Execution Not Found"), - GapicExecution( - name=_TEST_EXECUTION_NAME, - display_name=_TEST_DISPLAY_NAME, - schema_title=_TEST_SCHEMA_TITLE, - schema_version=_TEST_SCHEMA_VERSION, - description=_TEST_DESCRIPTION, - metadata=_TEST_METADATA, - ), - ] + get_execution_for_get_or_create_mock.side_effect = exceptions.NotFound( + "test: Execution Not Found" + ) yield get_execution_for_get_or_create_mock @@ -189,6 +201,58 @@ def create_execution_mock(): yield create_execution_mock +@pytest.fixture +def list_executions_mock(): + with patch.object(MetadataServiceClient, "list_executions") as list_executions_mock: + list_executions_mock.return_value = [ + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield list_executions_mock + + +@pytest.fixture +def query_execution_inputs_and_outputs_mock(): + with patch.object( + MetadataServiceClient, "query_execution_inputs_and_outputs" + ) as query_execution_inputs_and_outputs_mock: + query_execution_inputs_and_outputs_mock.return_value = LineageSubgraph( + artifacts=[ + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ], + ) + yield query_execution_inputs_and_outputs_mock + + @pytest.fixture def update_execution_mock(): with patch.object( @@ -233,17 +297,9 @@ def get_artifact_for_get_or_create_mock(): with patch.object( MetadataServiceClient, "get_artifact" ) as get_artifact_for_get_or_create_mock: - get_artifact_for_get_or_create_mock.side_effect = [ - exceptions.NotFound("test: Artifact Not Found"), - GapicArtifact( - name=_TEST_ARTIFACT_NAME, - display_name=_TEST_DISPLAY_NAME, - schema_title=_TEST_SCHEMA_TITLE, - schema_version=_TEST_SCHEMA_VERSION, - description=_TEST_DESCRIPTION, - metadata=_TEST_METADATA, - ), - ] + get_artifact_for_get_or_create_mock.side_effect = exceptions.NotFound( + "test: Artifact Not Found" + ) yield get_artifact_for_get_or_create_mock @@ -261,6 +317,30 @@ def create_artifact_mock(): yield create_artifact_mock +@pytest.fixture +def list_artifacts_mock(): + with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock: + list_artifacts_mock.return_value = [ + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield list_artifacts_mock + + @pytest.fixture def update_artifact_mock(): with patch.object(MetadataServiceClient, "update_artifact") as update_artifact_mock: @@ -317,8 +397,8 @@ def test_get_or_create_context( description=_TEST_DESCRIPTION, metadata=_TEST_METADATA, ) - get_context_for_get_or_create_mock.assert_has_calls( - calls=[call(name=_TEST_CONTEXT_NAME), call(name=_TEST_CONTEXT_NAME)] + get_context_for_get_or_create_mock.assert_called_once_with( + name=_TEST_CONTEXT_NAME ) create_context_mock.assert_called_once_with( parent=_TEST_PARENT, context_id=_TEST_CONTEXT_ID, context=expected_context, @@ -352,9 +432,34 @@ def test_update_context(self, update_context_mock): metadata=_TEST_UPDATED_METADATA, ) - update_context_mock.assert_called_once_with(context=updated_context,) + update_context_mock.assert_called_once_with(context=updated_context) assert my_context._gca_resource == updated_context + @pytest.mark.usefixtures("get_context_mock") + def test_list_contexts(self, list_contexts_mock): + aiplatform.init(project=_TEST_PROJECT) + + filter = "test-filter" + context_list = context._Context.list( + filter=filter, metadata_store_id=_TEST_METADATA_STORE + ) + + expected_context = GapicContext( + name=_TEST_CONTEXT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + list_contexts_mock.assert_called_once_with( + request=ListContextsRequest(parent=_TEST_PARENT, filter=filter,) + ) + assert len(context_list) == 2 + assert context_list[0]._gca_resource == expected_context + assert context_list[1]._gca_resource == expected_context + @pytest.mark.usefixtures("get_context_mock") def test_add_artifacts_and_executions( self, add_context_artifacts_and_executions_mock @@ -468,8 +573,8 @@ def test_get_or_create_execution( description=_TEST_DESCRIPTION, metadata=_TEST_METADATA, ) - get_execution_for_get_or_create_mock.assert_has_calls( - calls=[call(name=_TEST_EXECUTION_NAME), call(name=_TEST_EXECUTION_NAME)] + get_execution_for_get_or_create_mock.assert_called_once_with( + name=_TEST_EXECUTION_NAME ) create_execution_mock.assert_called_once_with( parent=_TEST_PARENT, @@ -508,6 +613,31 @@ def test_update_execution(self, update_execution_mock): update_execution_mock.assert_called_once_with(execution=updated_execution) assert my_execution._gca_resource == updated_execution + @pytest.mark.usefixtures("get_execution_mock") + def test_list_executions(self, list_executions_mock): + aiplatform.init(project=_TEST_PROJECT) + + filter = "test-filter" + execution_list = execution._Execution.list( + filter=filter, metadata_store_id=_TEST_METADATA_STORE + ) + + expected_execution = GapicExecution( + name=_TEST_EXECUTION_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + list_executions_mock.assert_called_once_with( + request=ListExecutionsRequest(parent=_TEST_PARENT, filter=filter,) + ) + assert len(execution_list) == 2 + assert execution_list[0]._gca_resource == expected_execution + assert execution_list[1]._gca_resource == expected_execution + @pytest.mark.usefixtures("get_execution_mock") def test_add_artifact(self, add_execution_events_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) @@ -529,6 +659,39 @@ def test_add_artifact(self, add_execution_events_mock): events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], ) + @pytest.mark.usefixtures("get_execution_mock") + def test_query_input_and_output_artifacts( + self, query_execution_inputs_and_outputs_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_execution = execution._Execution.get_or_create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + artifact_list = my_execution.query_input_and_output_artifacts() + + expected_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + query_execution_inputs_and_outputs_mock.assert_called_once_with( + execution=_TEST_EXECUTION_NAME, + ) + assert len(artifact_list) == 2 + assert artifact_list[0]._gca_resource == expected_artifact + assert artifact_list[1]._gca_resource == expected_artifact + class TestArtifact: def setup_method(self): @@ -572,8 +735,8 @@ def test_get_or_create_artifact( description=_TEST_DESCRIPTION, metadata=_TEST_METADATA, ) - get_artifact_for_get_or_create_mock.assert_has_calls( - calls=[call(name=_TEST_ARTIFACT_NAME), call(name=_TEST_ARTIFACT_NAME)] + get_artifact_for_get_or_create_mock.assert_called_once_with( + name=_TEST_ARTIFACT_NAME ) create_artifact_mock.assert_called_once_with( parent=_TEST_PARENT, @@ -611,3 +774,28 @@ def test_update_artifact(self, update_artifact_mock): update_artifact_mock.assert_called_once_with(artifact=updated_artifact) assert my_artifact._gca_resource == updated_artifact + + @pytest.mark.usefixtures("get_artifact_mock") + def test_list_artifacts(self, list_artifacts_mock): + aiplatform.init(project=_TEST_PROJECT) + + filter = "test-filter" + artifact_list = artifact._Artifact.list( + filter=filter, metadata_store_id=_TEST_METADATA_STORE + ) + + expected_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + list_artifacts_mock.assert_called_once_with( + request=ListArtifactsRequest(parent=_TEST_PARENT, filter=filter,) + ) + assert len(artifact_list) == 2 + assert artifact_list[0]._gca_resource == expected_artifact + assert artifact_list[1]._gca_resource == expected_artifact From ab979da6d702f61b09c408777151bff5b7431143 Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Thu, 22 Apr 2021 10:34:06 -0700 Subject: [PATCH 15/36] fix: rename get_experiment/pipeline to get_experiment_df and get_pipeline_df (#347) --- google/cloud/aiplatform/__init__.py | 8 +++--- google/cloud/aiplatform/metadata/metadata.py | 6 ++-- tests/unit/aiplatform/test_metadata.py | 30 ++++++++++---------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 49a301db1e..e25bc4470e 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -49,8 +49,8 @@ log_params = metadata.metadata_service.log_params log_metrics = metadata.metadata_service.log_metrics -get_experiment = metadata.metadata_service.get_experiment -get_pipeline = metadata.metadata_service.get_pipeline +get_experiment_df = metadata.metadata_service.get_experiment_df +get_pipeline_df = metadata.metadata_service.get_pipeline_df start_run = metadata.metadata_service.start_run @@ -60,8 +60,8 @@ "init", "log_params", "log_metrics", - "get_experiment", - "get_pipeline", + "get_experiment_df", + "get_pipeline_df", "start_run", "AutoMLImageTrainingJob", "AutoMLTabularTrainingJob", diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index c0314e921f..9eaeab0fdc 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -115,7 +115,7 @@ def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): ) artifact.update(metadata=metrics) - def get_experiment( + def get_experiment_df( self, experiment: Optional[str] = None ) -> "pd.DataFrame": # noqa: F821 """Returns a Pandas DataFrame of the parameters and metrics associated with one experiment. @@ -152,7 +152,7 @@ def get_experiment( """ if not experiment: - experiment = self._experiment + experiment = self._experiment.name source = "experiment" experiment_resource_name = self._get_experiment_or_pipeline_resource_name( @@ -165,7 +165,7 @@ def get_experiment( source=source, ) - def get_pipeline(self, pipeline: str) -> "pd.DataFrame": # noqa: F821 + def get_pipeline_df(self, pipeline: str) -> "pd.DataFrame": # noqa: F821 """Returns a Pandas DataFrame of the parameters and metrics associated with one pipeline. Args: diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index a3ad97f2c8..871cb2f9b6 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -272,7 +272,7 @@ def _assert_frame_equal_with_sorted_columns(dataframe_1, dataframe_2): import pandas as pd except ImportError: raise ImportError( - "Pandas is not installed and is required to test the get_experiment/pipeline method. " + "Pandas is not installed and is required to test the get_experiment_df/pipeline_df method. " 'Please install the SDK using "pip install python-aiplatform[full]"' ) @@ -381,19 +381,19 @@ def test_log_metrics( reason="Temporarily skip this test as extra required package are not installed in current setup" ) @pytest.mark.usefixtures("get_context_mock") - def test_get_experiment( + def test_get_experiment_df( self, list_executions_mock, query_execution_inputs_and_outputs_mock ): try: import pandas as pd except ImportError: raise ImportError( - "Pandas is not installed and is required to test the get_experiment method. " + "Pandas is not installed and is required to test the get_experiment_df method. " 'Please install the SDK using "pip install python-aiplatform[full]"' ) aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - experiment_df = aiplatform.get_experiment(_TEST_EXPERIMENT) + experiment_df = aiplatform.get_experiment_df(_TEST_EXPERIMENT) expected_filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{_TEST_CONTEXT_NAME}")' list_executions_mock.assert_called_once_with( @@ -428,34 +428,34 @@ def test_get_experiment( _assert_frame_equal_with_sorted_columns(experiment_df, experiment_df_truth) @pytest.mark.usefixtures("get_context_not_found_mock") - def test_get_experiment_not_exist(self): + def test_get_experiment_df_not_exist(self): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with pytest.raises(exceptions.NotFound): - aiplatform.get_experiment(_TEST_EXPERIMENT) + aiplatform.get_experiment_df(_TEST_EXPERIMENT) @pytest.mark.usefixtures("get_pipeline_context_mock") - def test_get_experiment_wrong_schema(self): + def test_get_experiment_df_wrong_schema(self): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with pytest.raises(ValueError): - aiplatform.get_experiment(_TEST_EXPERIMENT) + aiplatform.get_experiment_df(_TEST_EXPERIMENT) @pytest.mark.skip( reason="Temporarily skip this test as extra required package are not installed in current setup" ) @pytest.mark.usefixtures("get_pipeline_context_mock") - def test_get_pipeline( + def test_get_pipeline_df( self, list_executions_mock, query_execution_inputs_and_outputs_mock ): try: import pandas as pd except ImportError: raise ImportError( - "Pandas is not installed and is required to test the get_pipeline method. " + "Pandas is not installed and is required to test the get_pipeline_df method. " 'Please install the SDK using "pip install python-aiplatform[full]"' ) aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - pipeline_df = aiplatform.get_pipeline(_TEST_PIPELINE) + pipeline_df = aiplatform.get_pipeline_df(_TEST_PIPELINE) expected_filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{_TEST_CONTEXT_NAME}")' list_executions_mock.assert_called_once_with( @@ -490,13 +490,13 @@ def test_get_pipeline( _assert_frame_equal_with_sorted_columns(pipeline_df, pipeline_df_truth) @pytest.mark.usefixtures("get_context_not_found_mock") - def test_get_pipeline_not_exist(self): + def test_get_pipeline_df_not_exist(self): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with pytest.raises(exceptions.NotFound): - aiplatform.get_pipeline(_TEST_PIPELINE) + aiplatform.get_pipeline_df(_TEST_PIPELINE) @pytest.mark.usefixtures("get_context_mock") - def test_get_pipeline_wrong_schema(self): + def test_get_pipeline_df_wrong_schema(self): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with pytest.raises(ValueError): - aiplatform.get_pipeline(_TEST_PIPELINE) + aiplatform.get_pipeline_df(_TEST_PIPELINE) From 97fcf8e00ed7c5f6435568a6cbc3bcb3f507dbf8 Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Thu, 29 Apr 2021 08:53:16 -0700 Subject: [PATCH 16/36] feat: Add resource schema title validation on existing experiment/run. Also add metadata reset logics. (#354) --- google/cloud/aiplatform/initializer.py | 8 ++ google/cloud/aiplatform/metadata/metadata.py | 47 +++++++++ tests/unit/aiplatform/test_metadata.py | 103 ++++++++++++++++++- 3 files changed, 157 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index a07bf1e779..bbdf4d4aa9 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -82,6 +82,14 @@ def init( If set, this resource and all sub-resources will be secured by this key. """ + + # reset metadata_service config if project or location is updated. + if (project and project != self._project) or ( + location and location != self._location + ): + if metadata.metadata_service.experiment_name: + logging.info("project/location updated, reset Metadata config.") + metadata.metadata_service.reset() if project: self._project = project if location: diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index 9eaeab0fdc..e31350d466 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -32,7 +32,36 @@ def __init__(self): self._run = None self._metrics = None + def reset(self): + """Reset all _MetadataService fields to None""" + self._experiment = None + self._run = None + self._metrics = None + + @property + def experiment_name(self) -> Optional[str]: + """Return the experiment name of the _MetadataService, if experiment is not set, return None""" + if self._experiment: + return self._experiment.display_name + return None + + @property + def run_name(self) -> Optional[str]: + """Return the run name of the _MetadataService, if run is not set, return None""" + if self._run: + return self._run.display_name + return None + def set_experiment(self, experiment: str): + """Setup a experiment to current session. + + Args: + experiment (str): + Required. Name of the experiment to assign current session with. + Raises: + ValueError if a context with the same name as the experiment is create but with a different schema. + """ + _MetadataStore.get_or_create() context = _Context.get_or_create( resource_id=experiment, @@ -41,6 +70,11 @@ def set_experiment(self, experiment: str): schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], metadata=constants.EXPERIMENT_METADATA, ) + if context.schema_title != constants.SYSTEM_EXPERIMENT: + raise ValueError( + f"Experiment name {experiment} has been used to create other type of resources " + f"({context.schema_title}) in this MetadataStore, please choose a different experiment name." + ) self._experiment = context def start_run(self, run: str): @@ -49,6 +83,9 @@ def start_run(self, run: str): Args: run (str): Required. Name of the run to assign current session with. + Raise: + ValueError if experiment is not set. Or if run execution or metrics artifact + is already created but with a different schema. """ if not self._experiment: @@ -63,6 +100,11 @@ def start_run(self, run: str): schema_title=constants.SYSTEM_RUN, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], ) + if run_execution.schema_title != constants.SYSTEM_RUN: + raise ValueError( + f"Run name {run} has been used to create other type of resources ({run_execution.schema_title}) " + "in this MetadataStore, please choose a different run name." + ) self._experiment.add_artifacts_and_executions( execution_resource_names=[run_execution.resource_name] ) @@ -74,6 +116,11 @@ def start_run(self, run: str): schema_title=constants.SYSTEM_METRICS, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], ) + if metrics_artifact.schema_title != constants.SYSTEM_METRICS: + raise ValueError( + f"Run name {run} has been used to create other type of resources ({metrics_artifact.schema_title}) " + "in this MetadataStore, please choose a different run name." + ) run_execution.add_artifact( artifact_resource_name=metrics_artifact.resource_name, input=False ) diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index 871cb2f9b6..bdb1c7bc39 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -43,6 +43,7 @@ # project _TEST_PROJECT = "test-project" +_TEST_OTHER_PROJECT = "test-project-1" _TEST_LOCATION = "us-central1" _TEST_PARENT = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" @@ -88,6 +89,9 @@ _TEST_METRICS = {_TEST_METRIC_KEY_1: 222, _TEST_METRIC_KEY_2: 1} _TEST_OTHER_METRICS = {_TEST_METRIC_KEY_2: 0.9} +# schema +_TEST_WRONG_SCHEMA_TITLE = "system.WrongSchema" + @pytest.fixture def get_metadata_store_mock(): @@ -113,6 +117,21 @@ def get_context_mock(): yield get_context_mock +@pytest.fixture +def get_context_wrong_schema_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_context_wrong_schema_mock: + get_context_wrong_schema_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + schema_title=_TEST_WRONG_SCHEMA_TITLE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + yield get_context_wrong_schema_mock + + @pytest.fixture def get_pipeline_context_mock(): with patch.object( @@ -160,6 +179,20 @@ def get_execution_mock(): yield get_execution_mock +@pytest.fixture +def get_execution_wrong_schema_mock(): + with patch.object( + MetadataServiceClient, "get_execution" + ) as get_execution_wrong_schema_mock: + get_execution_wrong_schema_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=_TEST_WRONG_SCHEMA_TITLE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + yield get_execution_wrong_schema_mock + + @pytest.fixture def update_execution_mock(): with patch.object( @@ -254,6 +287,20 @@ def get_artifact_mock(): yield get_artifact_mock +@pytest.fixture +def get_artifact_wrong_schema_mock(): + with patch.object( + MetadataServiceClient, "get_artifact" + ) as get_artifact_wrong_schema_mock: + get_artifact_wrong_schema_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=_TEST_WRONG_SCHEMA_TITLE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + yield get_artifact_wrong_schema_mock + + @pytest.fixture def update_artifact_mock(): with patch.object(MetadataServiceClient, "update_artifact") as update_artifact_mock: @@ -284,8 +331,8 @@ def _assert_frame_equal_with_sorted_columns(dataframe_1, dataframe_2): class TestMetadata: def setup_method(self): reload(initializer) - reload(aiplatform) reload(metadata) + reload(aiplatform) def teardown_method(self): initializer.global_pool.shutdown(wait=True) @@ -300,6 +347,38 @@ def test_init_experiment_with_existing_metadataStore_and_context( get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_wrong_schema_mock") + def test_init_experiment_wrong_schema(self): + with pytest.raises(ValueError): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + ) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_init_experiment_reset(self): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + assert metadata.metadata_service.experiment_name == _TEST_EXPERIMENT + assert metadata.metadata_service.run_name == _TEST_RUN + + aiplatform.init(project=_TEST_OTHER_PROJECT, location=_TEST_LOCATION) + + assert metadata.metadata_service.experiment_name is None + assert metadata.metadata_service.run_name is None + @pytest.mark.usefixtures("get_metadata_store_mock") @pytest.mark.usefixtures("get_context_mock") def test_start_run_with_existing_execution_and_artifact( @@ -326,6 +405,28 @@ def test_start_run_with_existing_execution_and_artifact( events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], ) + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_wrong_schema_mock") + def test_start_run_with_wrong_run_execution_schema(self,): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + with pytest.raises(ValueError): + aiplatform.start_run(_TEST_RUN) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_wrong_schema_mock") + def test_start_run_with_wrong_metrics_artifact_schema(self,): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + with pytest.raises(ValueError): + aiplatform.start_run(_TEST_RUN) + @pytest.mark.usefixtures("get_metadata_store_mock") @pytest.mark.usefixtures("get_context_mock") @pytest.mark.usefixtures("get_execution_mock") From d6e65d20d7773689044cdb658621e1ac3c65039d Mon Sep 17 00:00:00 2001 From: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> Date: Tue, 4 May 2021 14:33:06 -0400 Subject: [PATCH 17/36] chore: merge main-test into dev-test (#362) --- .github/CODEOWNERS | 8 +- .pre-commit-config.yaml | 2 +- CONTRIBUTING.rst | 16 +- docs/aiplatform_v1beta1/services.rst | 1 + .../tensorboard_service.rst | 11 + google/cloud/aiplatform/base.py | 44 +- google/cloud/aiplatform/constants.py | 17 +- .../cloud/aiplatform/datasets/_datasources.py | 15 +- google/cloud/aiplatform/datasets/dataset.py | 4 +- .../aiplatform/datasets/image_dataset.py | 7 +- .../aiplatform/datasets/tabular_dataset.py | 3 +- .../cloud/aiplatform/datasets/text_dataset.py | 7 +- .../aiplatform/datasets/video_dataset.py | 7 +- google/cloud/aiplatform/initializer.py | 8 +- google/cloud/aiplatform/jobs.py | 38 +- google/cloud/aiplatform/models.py | 309 +- google/cloud/aiplatform/training_jobs.py | 156 +- google/cloud/aiplatform/training_utils.py | 2 +- google/cloud/aiplatform/utils.py | 54 +- .../predict/instance_v1/gapic_metadata.json | 7 + .../predict/params_v1/gapic_metadata.json | 7 + .../predict/prediction_v1/gapic_metadata.json | 7 + .../definition_v1/gapic_metadata.json | 7 + .../instance_v1beta1/gapic_metadata.json | 7 + .../params_v1beta1/gapic_metadata.json | 7 + .../prediction_v1beta1/gapic_metadata.json | 7 + .../definition_v1beta1/gapic_metadata.json | 7 + .../cloud/aiplatform_v1/gapic_metadata.json | 721 ++ .../services/migration_service/client.py | 22 +- google/cloud/aiplatform_v1beta1/__init__.py | 120 +- .../aiplatform_v1beta1/gapic_metadata.json | 1949 ++++ .../featurestore_service/async_client.py | 89 + .../services/featurestore_service/client.py | 90 + .../featurestore_service/transports/base.py | 14 + .../featurestore_service/transports/grpc.py | 29 + .../transports/grpc_asyncio.py | 30 + .../index_endpoint_service/async_client.py | 4 - .../services/index_endpoint_service/client.py | 16 - .../services/index_service/async_client.py | 2 +- .../services/index_service/client.py | 2 +- .../services/job_service/async_client.py | 32 +- .../services/job_service/client.py | 59 +- .../services/migration_service/client.py | 22 +- .../services/pipeline_service/async_client.py | 442 + .../services/pipeline_service/client.py | 543 ++ .../services/pipeline_service/pagers.py | 129 + .../pipeline_service/transports/base.py | 69 + .../pipeline_service/transports/grpc.py | 150 + .../transports/grpc_asyncio.py | 155 + .../services/tensorboard_service/__init__.py | 24 + .../tensorboard_service/async_client.py | 2346 +++++ .../services/tensorboard_service/client.py | 2647 ++++++ .../services/tensorboard_service/pagers.py | 700 ++ .../transports/__init__.py | 37 + .../tensorboard_service/transports/base.py | 509 ++ .../tensorboard_service/transports/grpc.py | 962 ++ .../transports/grpc_asyncio.py | 980 ++ .../aiplatform_v1beta1/types/__init__.py | 124 +- .../aiplatform_v1beta1/types/custom_job.py | 8 + .../aiplatform_v1beta1/types/entity_type.py | 3 + .../types/explanation_metadata.py | 2 +- .../cloud/aiplatform_v1beta1/types/feature.py | 4 + .../types/feature_monitoring_stats.py | 14 +- .../aiplatform_v1beta1/types/featurestore.py | 13 - .../types/featurestore_online_service.py | 56 - .../types/featurestore_service.py | 96 +- .../types/index_endpoint.py | 7 + .../aiplatform_v1beta1/types/index_service.py | 18 +- .../aiplatform_v1beta1/types/job_service.py | 38 +- .../types/metadata_service.py | 112 +- .../types/metadata_store.py | 20 + .../types/migration_service.py | 20 +- .../aiplatform_v1beta1/types/pipeline_job.py | 353 + .../types/pipeline_service.py | 140 + .../aiplatform_v1beta1/types/tensorboard.py | 108 + .../types/tensorboard_data.py | 161 + .../types/tensorboard_experiment.py | 95 + .../types/tensorboard_run.py | 74 + .../types/tensorboard_service.py | 892 ++ .../types/tensorboard_time_series.py | 123 + .../cloud/aiplatform_v1beta1/types/types.py | 6 +- .../types/user_action_reference.py | 2 +- .../cloud/aiplatform_v1beta1/types/value.py | 45 + noxfile.py | 3 - samples/model-builder/conftest.py | 115 +- ..._import_dataset_tabular_bigquery_sample.py | 36 + ...rt_dataset_tabular_bigquery_sample_test.py | 36 + ...e_and_import_dataset_tabular_gcs_sample.py | 37 + ..._import_dataset_tabular_gcs_sample_test.py | 36 + .../create_and_import_dataset_video_sample.py | 44 + ...te_and_import_dataset_video_sample_test.py | 42 + ...create_batch_prediction_job_sample_test.py | 2 +- .../model-builder/create_endpoint_sample.py | 34 + .../create_endpoint_sample_test.py | 36 + ...ate_training_pipeline_custom_job_sample.py | 69 + ...reate_training_pipeline_custom_job_test.py | 62 + ..._custom_training_managed_dataset_sample.py | 73 + ...ne_custom_training_managed_dataset_test.py | 70 + ...ng_pipeline_image_classification_sample.py | 6 +- ...peline_image_classification_sample_test.py | 5 +- ..._pipeline_tabular_classification_sample.py | 59 + ...line_tabular_classification_sample_test.py | 57 + ...ning_pipeline_tabular_regression_sample.py | 59 + ...pipeline_tabular_regression_sample_test.py | 57 + ...y_model_with_automatic_resources_sample.py | 57 + ...loy_model_with_automatic_resources_test.py | 52 + ...y_model_with_dedicated_resources_sample.py | 70 + ...loy_model_with_dedicated_resources_test.py | 62 + samples/model-builder/get_model_sample.py | 31 + samples/model-builder/get_model_test.py | 32 + ...data_text_entity_extraction_sample_test.py | 4 +- ...ata_text_sentiment_analysis_sample_test.py | 4 +- ...rt_data_video_action_recognition_sample.py | 45 + ...ta_video_action_recognition_sample_test.py | 44 + ...import_data_video_classification_sample.py | 46 + ...t_data_video_classification_sample_test.py | 42 + ...mport_data_video_object_tracking_sample.py | 45 + ..._data_video_object_tracking_sample_test.py | 44 + .../predict_tabular_classification_sample.py | 35 + ...dict_tabular_classification_sample_test.py | 33 + .../predict_tabular_regression_sample.py | 34 + .../predict_tabular_regression_sample_test.py | 33 + ...classification_single_label_sample_test.py | 4 +- ...dict_text_entity_extraction_sample_test.py | 4 +- ...ict_text_sentiment_analysis_sample_test.py | 4 +- samples/model-builder/test_constants.py | 98 + samples/model-builder/upload_model_sample.py | 71 + samples/model-builder/upload_model_test.py | 62 + samples/snippets/requirements.txt | 2 +- tests/__init__.py | 15 + tests/unit/__init__.py | 15 + tests/unit/aiplatform/test_endpoints.py | 3 + tests/unit/aiplatform/test_initializer.py | 5 + tests/unit/aiplatform/test_models.py | 286 +- tests/unit/aiplatform/test_training_jobs.py | 157 + tests/unit/gapic/__init__.py | 15 + .../aiplatform_v1/test_migration_service.py | 28 +- ...est_featurestore_online_serving_service.py | 1 - .../test_featurestore_service.py | 227 + .../test_index_endpoint_service.py | 51 +- .../aiplatform_v1beta1/test_job_service.py | 88 +- .../test_metadata_service.py | 12 +- .../test_migration_service.py | 28 +- .../test_pipeline_service.py | 1553 +++- .../test_tensorboard_service.py | 8115 +++++++++++++++++ 145 files changed, 28073 insertions(+), 615 deletions(-) create mode 100644 docs/aiplatform_v1beta1/tensorboard_service.rst create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_metadata.json create mode 100644 google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_metadata.json create mode 100644 google/cloud/aiplatform_v1/gapic_metadata.json create mode 100644 google/cloud/aiplatform_v1beta1/gapic_metadata.json create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/types/pipeline_job.py create mode 100644 google/cloud/aiplatform_v1beta1/types/tensorboard.py create mode 100644 google/cloud/aiplatform_v1beta1/types/tensorboard_data.py create mode 100644 google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py create mode 100644 google/cloud/aiplatform_v1beta1/types/tensorboard_run.py create mode 100644 google/cloud/aiplatform_v1beta1/types/tensorboard_service.py create mode 100644 google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py create mode 100644 google/cloud/aiplatform_v1beta1/types/value.py create mode 100644 samples/model-builder/create_and_import_dataset_tabular_bigquery_sample.py create mode 100644 samples/model-builder/create_and_import_dataset_tabular_bigquery_sample_test.py create mode 100644 samples/model-builder/create_and_import_dataset_tabular_gcs_sample.py create mode 100644 samples/model-builder/create_and_import_dataset_tabular_gcs_sample_test.py create mode 100644 samples/model-builder/create_and_import_dataset_video_sample.py create mode 100644 samples/model-builder/create_and_import_dataset_video_sample_test.py create mode 100644 samples/model-builder/create_endpoint_sample.py create mode 100644 samples/model-builder/create_endpoint_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_custom_job_sample.py create mode 100644 samples/model-builder/create_training_pipeline_custom_job_test.py create mode 100644 samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py create mode 100644 samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py create mode 100644 samples/model-builder/create_training_pipeline_tabular_classification_sample.py create mode 100644 samples/model-builder/create_training_pipeline_tabular_classification_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_tabular_regression_sample.py create mode 100644 samples/model-builder/create_training_pipeline_tabular_regression_sample_test.py create mode 100644 samples/model-builder/deploy_model_with_automatic_resources_sample.py create mode 100644 samples/model-builder/deploy_model_with_automatic_resources_test.py create mode 100644 samples/model-builder/deploy_model_with_dedicated_resources_sample.py create mode 100644 samples/model-builder/deploy_model_with_dedicated_resources_test.py create mode 100644 samples/model-builder/get_model_sample.py create mode 100644 samples/model-builder/get_model_test.py create mode 100644 samples/model-builder/import_data_video_action_recognition_sample.py create mode 100644 samples/model-builder/import_data_video_action_recognition_sample_test.py create mode 100644 samples/model-builder/import_data_video_classification_sample.py create mode 100644 samples/model-builder/import_data_video_classification_sample_test.py create mode 100644 samples/model-builder/import_data_video_object_tracking_sample.py create mode 100644 samples/model-builder/import_data_video_object_tracking_sample_test.py create mode 100644 samples/model-builder/predict_tabular_classification_sample.py create mode 100644 samples/model-builder/predict_tabular_classification_sample_test.py create mode 100644 samples/model-builder/predict_tabular_regression_sample.py create mode 100644 samples/model-builder/predict_tabular_regression_sample_test.py create mode 100644 samples/model-builder/upload_model_sample.py create mode 100644 samples/model-builder/upload_model_test.py create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/gapic/__init__.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8de778714c..ccdc098900 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -14,8 +14,8 @@ /google/cloud/aiplatform/* @googleapis/cloud-aiplatform-model-builder-sdk /tests/unit/aiplatform/* @googleapis/cloud-aiplatform-model-builder-sdk -# The python-samples-owners team is the default owner for samples -/samples/**/*.py @dizcology @googleapis/python-samples-owners +# The Cloud AI DPE team is the default owner for samples +/samples/**/*.py @googleapis/cdpe-cloudai @googleapis/python-samples-owners -# The enhanced client library tests are owned by @telpirion -/tests/unit/enhanced_library/*.py @telpirion +# The enhanced client library tests are owned by Cloud AI DPE +/tests/unit/enhanced_library/*.py @googleapis/cdpe-cloudai diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8912e9b5d7..1bbd787833 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,6 @@ repos: hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.9.0 + rev: 3.9.1 hooks: - id: flake8 diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 66216c172d..f865e3769d 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -160,21 +160,7 @@ Running System Tests auth settings and change some configuration in your project to run all the tests. -- System tests will be run against an actual project and - so you'll need to provide some environment variables to facilitate - authentication to your project: - - - ``GOOGLE_APPLICATION_CREDENTIALS``: The path to a JSON key file; - Such a file can be downloaded directly from the developer's console by clicking - "Generate new JSON key". See private key - `docs `__ - for more details. - -- Once you have downloaded your json keys, set the environment variable - ``GOOGLE_APPLICATION_CREDENTIALS`` to the absolute path of the json file:: - - $ export GOOGLE_APPLICATION_CREDENTIALS="/Users//path/to/app_credentials.json" - +- System tests will be run against an actual project. You should use local credentials from gcloud when possible. See `Best practices for application authentication `__. Some tests require a service account. For those tests see `Authenticating as a service account `__. ************* Test Coverage diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index f715a7c1f4..490112c7d9 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -16,4 +16,5 @@ Services for Google Cloud Aiplatform v1beta1 API pipeline_service prediction_service specialist_pool_service + tensorboard_service vizier_service diff --git a/docs/aiplatform_v1beta1/tensorboard_service.rst b/docs/aiplatform_v1beta1/tensorboard_service.rst new file mode 100644 index 0000000000..423efcd796 --- /dev/null +++ b/docs/aiplatform_v1beta1/tensorboard_service.rst @@ -0,0 +1,11 @@ +TensorboardService +------------------------------------ + +.. automodule:: google.cloud.aiplatform_v1beta1.services.tensorboard_service + :members: + :inherited-members: + + +.. automodule:: google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers + :members: + :inherited-members: diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 907397b7e8..f46db9c47e 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -94,7 +94,6 @@ def log_create_complete( resource (proto.Message): AI Platform Resourc proto.Message variable_name (str): Name of variable to use for code snippet - """ self._logger.info(f"{cls.__name__} created. Resource name: {resource.name}") self._logger.info(f"To use this {cls.__name__} in another session:") @@ -181,7 +180,8 @@ def _raise_future_exception(self): raise self._exception def _complete_future(self, future: futures.Future): - """Checks for exception of future and removes the pointer if it's still latest. + """Checks for exception of future and removes the pointer if it's still + latest. Args: future (futures.Future): Required. A future to complete. @@ -215,13 +215,14 @@ def wait(self): @property def _latest_future(self) -> Optional[futures.Future]: - """Get the latest future if it exists""" + """Get the latest future if it exists.""" with self.__latest_future_lock: return self.__latest_future @_latest_future.setter def _latest_future(self, future: Optional[futures.Future]): - """Optionally set the latest future and add a complete_future callback.""" + """Optionally set the latest future and add a complete_future + callback.""" with self.__latest_future_lock: self.__latest_future = future if future: @@ -260,7 +261,8 @@ def wait_for_dependencies_and_invoke( kwargs: Dict[str, Any], internal_callbacks: Iterable[Callable[[Any], Any]], ) -> Any: - """Wrapper method to wait on any dependencies before submitting method. + """Wrapper method to wait on any dependencies before submitting + method. Args: deps (Sequence[futures.Future]): @@ -272,7 +274,6 @@ def wait_for_dependencies_and_invoke( Required. The keyword arguments to call the method with. internal_callbacks: (Callable[[Any], Any]): Callbacks that take the result of method. - """ for future in set(deps): @@ -342,12 +343,14 @@ def wait_for_dependencies_and_invoke( @classmethod @abc.abstractmethod def _empty_constructor(cls) -> "FutureManager": - """Should construct object with all non FutureManager attributes as None""" + """Should construct object with all non FutureManager attributes as + None.""" pass @abc.abstractmethod def _sync_object_with_future_result(self, result: "FutureManager"): - """Should sync the object from _empty_constructor with result of future.""" + """Should sync the object from _empty_constructor with result of + future.""" def __repr__(self) -> str: if self._exception: @@ -375,7 +378,8 @@ class AiPlatformResourceNoun(metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod def client_class(cls) -> Type[utils.AiPlatformServiceClientWithOverride]: - """Client class required to interact with resource with optional overrides.""" + """Client class required to interact with resource with optional + overrides.""" pass @property @@ -388,7 +392,8 @@ def _is_client_prediction_client(cls) -> bool: @property @abc.abstractmethod def _getter_method(cls) -> str: - """Name of getter method of client class for retrieving the resource.""" + """Name of getter method of client class for retrieving the + resource.""" pass @property @@ -400,7 +405,7 @@ def _delete_method(cls) -> str: @property @abc.abstractmethod def _resource_noun(cls) -> str: - """Resource noun""" + """Resource noun.""" pass def __init__( @@ -547,7 +552,8 @@ def optional_sync( return_input_arg: Optional[str] = None, bind_future_to_self: bool = True, ): - """Decorator for AiPlatformResourceNounWithFutureManager with optional sync support. + """Decorator for AiPlatformResourceNounWithFutureManager with optional sync + support. Methods with this decorator should include a "sync" argument that defaults to True. If called with sync=False this decorator will launch the method as a @@ -681,7 +687,8 @@ def wrapper(*args, **kwargs): class AiPlatformResourceNounWithFutureManager(AiPlatformResourceNoun, FutureManager): - """Allows optional asynchronous calls to this AI Platform Resource Nouns.""" + """Allows optional asynchronous calls to this AI Platform Resource + Nouns.""" def __init__( self, @@ -816,7 +823,8 @@ def _list( credentials: Optional[auth_credentials.Credentials] = None, ) -> List[AiPlatformResourceNoun]: """Private method to list all instances of this AI Platform Resource, - takes a `cls_filter` arg to filter to a particular SDK resource subclass. + takes a `cls_filter` arg to filter to a particular SDK resource + subclass. Args: cls_filter (Callable[[proto.Message], bool]): @@ -884,8 +892,9 @@ def _list_with_local_order( credentials: Optional[auth_credentials.Credentials] = None, ) -> List[AiPlatformResourceNoun]: """Private method to list all instances of this AI Platform Resource, - takes a `cls_filter` arg to filter to a particular SDK resource subclass. - Provides client-side sorting when a list API doesn't support `order_by`. + takes a `cls_filter` arg to filter to a particular SDK resource + subclass. Provides client-side sorting when a list API doesn't support + `order_by`. Args: cls_filter (Callable[[proto.Message], bool]): @@ -986,7 +995,8 @@ def list( @optional_sync() def delete(self, sync: bool = True) -> None: - """Deletes this AI Platform resource. WARNING: This deletion is permament. + """Deletes this AI Platform resource. WARNING: This deletion is + permament. Args: sync (bool): diff --git a/google/cloud/aiplatform/constants.py b/google/cloud/aiplatform/constants.py index 62c28009c2..9e71427d0f 100644 --- a/google/cloud/aiplatform/constants.py +++ b/google/cloud/aiplatform/constants.py @@ -16,7 +16,22 @@ # DEFAULT_REGION = "us-central1" -SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1") +SUPPORTED_REGIONS = { + "asia-east1", + "asia-northeast1", + "asia-northeast3", + "asia-southeast1", + "australia-southeast1", + "europe-west1", + "europe-west2", + "europe-west4", + "northamerica-northeast1", + "us-central1", + "us-east1", + "us-east4", + "us-west1", +} + API_BASE_PATH = "aiplatform.googleapis.com" # Batch Prediction diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py index eefd1b04fd..23a89cc157 100644 --- a/google/cloud/aiplatform/datasets/_datasources.py +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -26,7 +26,7 @@ class Datasource(abc.ABC): - """An abstract class that sets dataset_metadata""" + """An abstract class that sets dataset_metadata.""" @property @abc.abstractmethod @@ -36,7 +36,7 @@ def dataset_metadata(self): class DatasourceImportable(abc.ABC): - """An abstract class that sets import_data_config""" + """An abstract class that sets import_data_config.""" @property @abc.abstractmethod @@ -46,14 +46,14 @@ def import_data_config(self): class TabularDatasource(Datasource): - """Datasource for creating a tabular dataset for AI Platform""" + """Datasource for creating a tabular dataset for AI Platform.""" def __init__( self, gcs_source: Optional[Union[str, Sequence[str]]] = None, bq_source: Optional[str] = None, ): - """Creates a tabular datasource + """Creates a tabular datasource. Args: gcs_source (Union[str, Sequence[str]]): @@ -99,7 +99,7 @@ def dataset_metadata(self) -> Optional[Dict]: class NonTabularDatasource(Datasource): - """Datasource for creating an empty non-tabular dataset for AI Platform""" + """Datasource for creating an empty non-tabular dataset for AI Platform.""" @property def dataset_metadata(self) -> Optional[Dict]: @@ -107,7 +107,8 @@ def dataset_metadata(self) -> Optional[Dict]: class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable): - """Datasource for creating a non-tabular dataset for AI Platform and importing data to the dataset""" + """Datasource for creating a non-tabular dataset for AI Platform and + importing data to the dataset.""" def __init__( self, @@ -115,7 +116,7 @@ def __init__( import_schema_uri: str, data_item_labels: Optional[Dict] = None, ): - """Creates a non-tabular datasource + """Creates a non-tabular datasource. Args: gcs_source (Union[str, Sequence[str]]): diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 25078ab2c5..44dadc4ee4 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -36,7 +36,7 @@ class _Dataset(base.AiPlatformResourceNounWithFutureManager): - """Managed dataset resource for AI Platform""" + """Managed dataset resource for AI Platform.""" client_class = utils.DatasetClientWithOverride _is_client_prediction_client = False @@ -70,7 +70,6 @@ def __init__( credentials (auth_credentials.Credentials): Custom credentials to use to upload this model. Overrides credentials set in aiplatform.init. - """ super().__init__( @@ -195,7 +194,6 @@ def create( Returns: dataset (Dataset): Instantiated representation of the managed dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py index 32db96bea1..c2b3ca68b5 100644 --- a/google/cloud/aiplatform/datasets/image_dataset.py +++ b/google/cloud/aiplatform/datasets/image_dataset.py @@ -27,7 +27,7 @@ class ImageDataset(datasets._Dataset): - """Managed image dataset resource for AI Platform""" + """Managed image dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.image, @@ -47,8 +47,8 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> "ImageDataset": - """Creates a new image dataset and optionally imports data into dataset when - source and import_schema_uri are passed. + """Creates a new image dataset and optionally imports data into dataset + when source and import_schema_uri are passed. Args: display_name (str): @@ -114,7 +114,6 @@ def create( Returns: image_dataset (ImageDataset): Instantiated representation of the managed image dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py index 3dd217aad7..06ba4a3394 100644 --- a/google/cloud/aiplatform/datasets/tabular_dataset.py +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -27,7 +27,7 @@ class TabularDataset(datasets._Dataset): - """Managed tabular dataset resource for AI Platform""" + """Managed tabular dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.tabular, @@ -95,7 +95,6 @@ def create( Returns: tabular_dataset (TabularDataset): Instantiated representation of the managed tabular dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py index c27fed59ad..6f6fd57bda 100644 --- a/google/cloud/aiplatform/datasets/text_dataset.py +++ b/google/cloud/aiplatform/datasets/text_dataset.py @@ -27,7 +27,7 @@ class TextDataset(datasets._Dataset): - """Managed text dataset resource for AI Platform""" + """Managed text dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.text, @@ -47,8 +47,8 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> "TextDataset": - """Creates a new text dataset and optionally imports data into dataset when - source and import_schema_uri are passed. + """Creates a new text dataset and optionally imports data into dataset + when source and import_schema_uri are passed. Example Usage: ds = aiplatform.TextDataset.create( @@ -121,7 +121,6 @@ def create( Returns: text_dataset (TextDataset): Instantiated representation of the managed text dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py index 84af000df4..7064c8b7cf 100644 --- a/google/cloud/aiplatform/datasets/video_dataset.py +++ b/google/cloud/aiplatform/datasets/video_dataset.py @@ -27,7 +27,7 @@ class VideoDataset(datasets._Dataset): - """Managed video dataset resource for AI Platform""" + """Managed video dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.video, @@ -47,8 +47,8 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> "VideoDataset": - """Creates a new video dataset and optionally imports data into dataset when - source and import_schema_uri are passed. + """Creates a new video dataset and optionally imports data into dataset + when source and import_schema_uri are passed. Args: display_name (str): @@ -114,7 +114,6 @@ def create( Returns: video_dataset (VideoDataset): Instantiated representation of the managed video dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index bbdf4d4aa9..eecbac61c7 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -114,8 +114,9 @@ def get_encryption_spec( gca_encryption_spec_v1beta1.EncryptionSpec, ] ]: - """Creates a gca_encryption_spec.EncryptionSpec instance from the given key name. - If the provided key name is None, it uses the default key name if provided. + """Creates a gca_encryption_spec.EncryptionSpec instance from the given + key name. If the provided key name is None, it uses the default key + name if provided. Args: encryption_spec_key_name (Optional[str]): The default encryption key name to use when creating resources. @@ -243,7 +244,8 @@ def create_client( location_override: Optional[str] = None, prediction_client: bool = False, ) -> utils.AiPlatformServiceClientWithOverride: - """Instantiates a given AiPlatformServiceClient with optional overrides. + """Instantiates a given AiPlatformServiceClient with optional + overrides. Args: client_class (utils.AiPlatformServiceClientWithOverride): diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index a7f2bbd31d..ee6d46dde9 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -64,8 +64,7 @@ class _Job(base.AiPlatformResourceNounWithFutureManager): - """ - Class that represents a general Job resource in AI Platform (Unified). + """Class that represents a general Job resource in AI Platform (Unified). Cannot be directly instantiated. Serves as base class to specific Job types, i.e. BatchPredictionJob or @@ -79,7 +78,7 @@ class _Job(base.AiPlatformResourceNounWithFutureManager): _delete_method (str): The name of the specific JobServiceClient delete method """ - client_class = utils.JobpointClientWithOverride + client_class = utils.JobClientWithOverride _is_client_prediction_client = False def __init__( @@ -89,8 +88,8 @@ def __init__( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ): - """ - Retrives Job subclass resource by calling a subclass-specific getter method. + """Retrives Job subclass resource by calling a subclass-specific getter + method. Args: job_name (str): @@ -142,7 +141,8 @@ def _cancel_method(cls) -> str: pass def _dashboard_uri(self) -> Optional[str]: - """Helper method to compose the dashboard uri where job can be viewed.""" + """Helper method to compose the dashboard uri where job can be + viewed.""" fields = utils.extract_fields_from_resource_name(self.resource_name) url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}" return url @@ -152,7 +152,6 @@ def _block_until_complete(self): Raises: RuntimeError: If job failed or cancelled. - """ # Used these numbers so failures surface fast @@ -232,8 +231,11 @@ def list( ) def cancel(self) -> None: - """Cancels this Job. Success of cancellation is not guaranteed. Use `Job.state` - property to verify if cancellation was successful.""" + """Cancels this Job. + + Success of cancellation is not guaranteed. Use `Job.state` + property to verify if cancellation was successful. + """ _LOGGER.log_action_start_against_resource("Cancelling", "run", self) getattr(self.api_client, self._cancel_method)(name=self.resource_name) @@ -255,8 +257,8 @@ def __init__( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ): - """ - Retrieves a BatchPredictionJob resource and instantiates its representation. + """Retrieves a BatchPredictionJob resource and instantiates its + representation. Args: batch_prediction_job_name (str): @@ -463,7 +465,6 @@ def create( Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. - """ utils.validate_display_name(job_display_name) @@ -655,7 +656,6 @@ def _create( If no or multiple source or destinations are provided. Also, if provided instances_format or predictions_format are not supported by AI Platform. - """ # select v1beta1 if explain else use default v1 if generate_explanation: @@ -687,9 +687,9 @@ def _create( def iter_outputs( self, bq_max_results: Optional[int] = 100 ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: - """Returns an Iterable object to traverse the output files, either a list - of GCS Blobs or a BigQuery RowIterator depending on the output config set - when the BatchPredictionJob was created. + """Returns an Iterable object to traverse the output files, either a + list of GCS Blobs or a BigQuery RowIterator depending on the output + config set when the BatchPredictionJob was created. Args: bq_max_results: Optional[int] = 100 @@ -724,7 +724,8 @@ def iter_outputs( # Build a Storage Client using the same credentials as JobServiceClient storage_client = storage.Client( - credentials=self.api_client._transport._credentials + project=self.project, + credentials=self.api_client._transport._credentials, ) gcs_bucket, gcs_prefix = utils.extract_bucket_and_prefix_from_gcs_path( @@ -740,7 +741,8 @@ def iter_outputs( # Build a BigQuery Client using the same credentials as JobServiceClient bq_client = bigquery.Client( - credentials=self.api_client._transport._credentials + project=self.project, + credentials=self.api_client._transport._credentials, ) # Format from service is `bq://projectId.bqDatasetId` diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index d96b681695..cecc992644 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -17,6 +17,7 @@ import proto from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union +from google.api_core import operation from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base @@ -35,9 +36,11 @@ endpoint_v1 as gca_endpoint_v1, endpoint_v1beta1 as gca_endpoint_v1beta1, explanation_v1beta1 as gca_explanation_v1beta1, + io as gca_io_compat, machine_resources as gca_machine_resources_compat, machine_resources_v1beta1 as gca_machine_resources_v1beta1, model as gca_model_compat, + model_service as gca_model_service_compat, model_v1beta1 as gca_model_v1beta1, env_var as gca_env_var_compat, env_var_v1beta1 as gca_env_var_v1beta1, @@ -217,8 +220,8 @@ def _create( encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, sync=True, ) -> "Endpoint": - """ - Creates a new endpoint by calling the API client. + """Creates a new endpoint by calling the API client. + Args: api_client (EndpointServiceClient): Required. An instance of EndpointServiceClient with the correct @@ -296,9 +299,8 @@ def _create( def _allocate_traffic( traffic_split: Dict[str, int], traffic_percentage: int, ) -> Dict[str, int]: - """ - Allocates desired traffic to new deployed model and scales traffic of - older deployed models. + """Allocates desired traffic to new deployed model and scales traffic + of older deployed models. Args: traffic_split (Dict[str, int]): @@ -333,9 +335,8 @@ def _allocate_traffic( def _unallocate_traffic( traffic_split: Dict[str, int], deployed_model_id: str, ) -> Dict[str, int]: - """ - Sets deployed model id's traffic to 0 and scales the traffic of other - deployed models. + """Sets deployed model id's traffic to 0 and scales the traffic of + other deployed models. Args: traffic_split (Dict[str, int]): @@ -402,7 +403,7 @@ def _validate_deploy_args( accelerator_type (str): Required. Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_T4 deployed_model_display_name (str): Required. The display name of the DeployedModel. If not provided upon creation, the Model's display_name is used. @@ -431,11 +432,11 @@ def _validate_deploy_args( For more details, see `Ref docs ` Raises: - ValueError if Min or Max replica is negative. Traffic percentage > 100 or - < 0. Or if traffic_split does not sum to 100. + ValueError: if Min or Max replica is negative. Traffic percentage > 100 or + < 0. Or if traffic_split does not sum to 100. - ValueError if either explanation_metadata or explanation_parameters - but not both are specified. + ValueError: if either explanation_metadata or explanation_parameters + but not both are specified. """ if min_replica_count < 0: raise ValueError("Min replica cannot be negative.") @@ -477,13 +478,13 @@ def deploy( max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, + service_account: Optional[str] = None, explanation_metadata: Optional[explain.ExplanationMetadata] = None, explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, ) -> None: - """ - Deploys a Model to the Endpoint. + """Deploys a Model to the Endpoint. Args: model (aiplatform.Model): @@ -528,9 +529,16 @@ def deploy( accelerator_type (str): Optional. Hardware accelerator type. Must also set accelerator_count if used. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, - NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. + service_account (str): + The service account that the DeployedModel's container runs as. Specify the + email address of the service account. If this service account is not + specified, the container runs as a service account that doesn't have access + to the resource project. + Users deploying the Model must have the `iam.serviceAccounts.actAs` + permission on this service account. explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be @@ -569,6 +577,7 @@ def deploy( max_replica_count=max_replica_count, accelerator_type=accelerator_type, accelerator_count=accelerator_count, + service_account=service_account, explanation_metadata=explanation_metadata, explanation_parameters=explanation_parameters, metadata=metadata, @@ -583,17 +592,17 @@ def _deploy( traffic_percentage: Optional[int] = 0, traffic_split: Optional[Dict[str, int]] = None, machine_type: Optional[str] = None, - min_replica_count: Optional[int] = 1, - max_replica_count: Optional[int] = 1, + min_replica_count: int = 1, + max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, + service_account: Optional[str] = None, explanation_metadata: Optional[explain.ExplanationMetadata] = None, explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, ) -> None: - """ - Deploys a Model to the Endpoint. + """Deploys a Model to the Endpoint. Args: model (aiplatform.Model): @@ -638,9 +647,16 @@ def _deploy( accelerator_type (str): Optional. Hardware accelerator type. Must also set accelerator_count if used. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, - NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. + service_account (str): + The service account that the DeployedModel's container runs as. Specify the + email address of the service account. If this service account is not + specified, the container runs as a service account that doesn't have access + to the resource project. + Users deploying the Model must have the `iam.serviceAccounts.actAs` + permission on this service account. explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be @@ -677,6 +693,7 @@ def _deploy( max_replica_count=max_replica_count, accelerator_type=accelerator_type, accelerator_count=accelerator_count, + service_account=service_account, explanation_metadata=explanation_metadata, explanation_parameters=explanation_parameters, metadata=metadata, @@ -697,10 +714,11 @@ def _deploy_call( traffic_percentage: Optional[int] = 0, traffic_split: Optional[Dict[str, int]] = None, machine_type: Optional[str] = None, - min_replica_count: Optional[int] = 1, - max_replica_count: Optional[int] = 1, + min_replica_count: int = 1, + max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, + service_account: Optional[str] = None, explanation_metadata: Optional[explain.ExplanationMetadata] = None, explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), @@ -753,6 +771,13 @@ def _deploy_call( is not provided, the larger value of min_replica_count or 1 will be used. If value provided is smaller than min_replica_count, it will automatically be increased to be min_replica_count. + service_account (str): + The service account that the DeployedModel's container runs as. Specify the + email address of the service account. If this service account is not + specified, the container runs as a service account that doesn't have access + to the resource project. + Users deploying the Model must have the `iam.serviceAccounts.actAs` + permission on this service account. explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be @@ -769,9 +794,9 @@ def _deploy_call( will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. Raises: - ValueError if there is not current traffic split and traffic percentage + ValueError: If there is not current traffic split and traffic percentage is not 0 or 100. - ValueError if only `explanation_metadata` or `explanation_parameters` + ValueError: If only `explanation_metadata` or `explanation_parameters` is specified. """ @@ -788,6 +813,12 @@ def _deploy_call( gca_endpoint = gca_endpoint_v1beta1 gca_machine_resources = gca_machine_resources_v1beta1 + deployed_model = gca_endpoint.DeployedModel( + model=model_resource_name, + display_name=deployed_model_display_name, + service_account=service_account, + ) + if machine_type: machine_spec = gca_machine_resources.MachineSpec(machine_type=machine_type) @@ -796,26 +827,17 @@ def _deploy_call( machine_spec.accelerator_type = accelerator_type machine_spec.accelerator_count = accelerator_count - dedicated_resources = gca_machine_resources.DedicatedResources( + deployed_model.dedicated_resources = gca_machine_resources.DedicatedResources( machine_spec=machine_spec, min_replica_count=min_replica_count, max_replica_count=max_replica_count, ) - deployed_model = gca_endpoint.DeployedModel( - dedicated_resources=dedicated_resources, - model=model_resource_name, - display_name=deployed_model_display_name, - ) + else: - automatic_resources = gca_machine_resources.AutomaticResources( + deployed_model.automatic_resources = gca_machine_resources.AutomaticResources( min_replica_count=min_replica_count, max_replica_count=max_replica_count, ) - deployed_model = gca_endpoint.DeployedModel( - automatic_resources=automatic_resources, - model=model_resource_name, - display_name=deployed_model_display_name, - ) # Service will throw error if both metadata and parameters are not provided if explanation_metadata and explanation_parameters: @@ -964,7 +986,8 @@ def _instantiate_prediction_client( credentials: Optional[auth_credentials.Credentials] = None, ) -> utils.PredictionClientWithOverride: - """Helper method to instantiates prediction client with optional overrides for this endpoint. + """Helper method to instantiates prediction client with optional + overrides for this endpoint. Args: location (str): The location of this endpoint. @@ -1007,7 +1030,6 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict ``parameters_schema_uri``. Returns: prediction: Prediction with returned predictions and Model Id. - """ self.wait() @@ -1198,6 +1220,26 @@ def description(self): """Description of the model.""" return self._gca_resource.description + @property + def supported_export_formats( + self, + ) -> Dict[str, List[gca_model_compat.Model.ExportFormat.ExportableContent]]: + """The formats and content types in which this Model may be exported. + If empty, this Model is not available for export. + + For example, if this model can be exported as a Tensorflow SavedModel and + have the artifacts written to Cloud Storage, the expected value would be: + + {'tf-saved-model': []} + """ + return { + export_format.id: [ + gca_model_compat.Model.ExportFormat.ExportableContent(content) + for content in export_format.exportable_contents + ] + for export_format in self._gca_resource.supported_export_formats + } + def __init__( self, model_name: str, @@ -1259,7 +1301,8 @@ def upload( encryption_spec_key_name: Optional[str] = None, sync=True, ) -> "Model": - """Uploads a model and returns a Model representing the uploaded Model resource. + """Uploads a model and returns a Model representing the uploaded Model + resource. Example usage: @@ -1392,7 +1435,7 @@ def upload( Returns: model: Instantiated representation of the uploaded model resource. Raises: - ValueError if only `explanation_metadata` or `explanation_parameters` + ValueError: If only `explanation_metadata` or `explanation_parameters` is specified. """ utils.validate_display_name(display_name) @@ -1489,18 +1532,18 @@ def deploy( traffic_percentage: Optional[int] = 0, traffic_split: Optional[Dict[str, int]] = None, machine_type: Optional[str] = None, - min_replica_count: Optional[int] = 1, - max_replica_count: Optional[int] = 1, + min_replica_count: int = 1, + max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, + service_account: Optional[str] = None, explanation_metadata: Optional[explain.ExplanationMetadata] = None, explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, sync=True, ) -> Endpoint: - """ - Deploys model to endpoint. Endpoint will be created if unspecified. + """Deploys model to endpoint. Endpoint will be created if unspecified. Args: endpoint ("Endpoint"): @@ -1545,9 +1588,16 @@ def deploy( accelerator_type (str): Optional. Hardware accelerator type. Must also set accelerator_count if used. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, - NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. + service_account (str): + The service account that the DeployedModel's container runs as. Specify the + email address of the service account. If this service account is not + specified, the container runs as a service account that doesn't have access + to the resource project. + Users deploying the Model must have the `iam.serviceAccounts.actAs` + permission on this service account. explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be @@ -1577,7 +1627,6 @@ def deploy( Returns: endpoint ("Endpoint"): Endpoint with the deployed model. - """ Endpoint._validate_deploy_args( @@ -1601,6 +1650,7 @@ def deploy( max_replica_count=max_replica_count, accelerator_type=accelerator_type, accelerator_count=accelerator_count, + service_account=service_account, explanation_metadata=explanation_metadata, explanation_parameters=explanation_parameters, metadata=metadata, @@ -1617,18 +1667,18 @@ def _deploy( traffic_percentage: Optional[int] = 0, traffic_split: Optional[Dict[str, int]] = None, machine_type: Optional[str] = None, - min_replica_count: Optional[int] = 1, - max_replica_count: Optional[int] = 1, + min_replica_count: int = 1, + max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, + service_account: Optional[str] = None, explanation_metadata: Optional[explain.ExplanationMetadata] = None, explanation_parameters: Optional[explain.ExplanationParameters] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> Endpoint: - """ - Deploys model to endpoint. Endpoint will be created if unspecified. + """Deploys model to endpoint. Endpoint will be created if unspecified. Args: endpoint ("Endpoint"): @@ -1673,9 +1723,16 @@ def _deploy( accelerator_type (str): Optional. Hardware accelerator type. Must also set accelerator_count if used. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, - NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 accelerator_count (int): Optional. The number of accelerators to attach to a worker replica. + service_account (str): + The service account that the DeployedModel's container runs as. Specify the + email address of the service account. If this service account is not + specified, the container runs as a service account that doesn't have access + to the resource project. + Users deploying the Model must have the `iam.serviceAccounts.actAs` + permission on this service account. explanation_metadata (explain.ExplanationMetadata): Optional. Metadata describing the Model's input and output for explanation. Both `explanation_metadata` and `explanation_parameters` must be @@ -1732,6 +1789,7 @@ def _deploy( max_replica_count=max_replica_count, accelerator_type=accelerator_type, accelerator_count=accelerator_count, + service_account=service_account, explanation_metadata=explanation_metadata, explanation_parameters=explanation_parameters, metadata=metadata, @@ -1766,9 +1824,10 @@ def batch_predict( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> jobs.BatchPredictionJob: - """Creates a batch prediction job using this Model and outputs prediction - results to the provided destination prefix in the specified - `predictions_format`. One source and one destination prefix are required. + """Creates a batch prediction job using this Model and outputs + prediction results to the provided destination prefix in the specified + `predictions_format`. One source and one destination prefix are + required. Example usage: @@ -1919,7 +1978,6 @@ def batch_predict( Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. - """ self.wait() @@ -1995,3 +2053,142 @@ def list( location=location, credentials=credentials, ) + + @base.optional_sync() + def _wait_on_export(self, operation_future: operation.Operation, sync=True) -> None: + operation_future.result() + + def export_model( + self, + export_format_id: str, + artifact_destination: Optional[str] = None, + image_destination: Optional[str] = None, + sync: bool = True, + ) -> Dict[str, str]: + """Exports a trained, exportable Model to a location specified by the user. + A Model is considered to be exportable if it has at least one `supported_export_formats`. + Either `artifact_destination` or `image_destination` must be provided. + + Usage: + my_model.export( + export_format_id='tf-saved-model' + artifact_destination='gs://my-bucket/models/' + ) + + or + + my_model.export( + export_format_id='custom-model' + image_destination='us-central1-docker.pkg.dev/projectId/repo/image' + ) + + Args: + export_format_id (str): + Required. The ID of the format in which the Model must be exported. + The list of export formats that this Model supports can be found + by calling `Model.supported_export_formats`. + artifact_destination (str): + The Cloud Storage location where the Model artifact is to be + written to. Under the directory given as the destination a + new one with name + "``model-export--``", + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format, will be created. Inside, the Model and any of its + supporting files will be written. + + This field should only be set when, in [Model.supported_export_formats], + the value for the key given in `export_format_id` contains ``ARTIFACT``. + image_destination (str): + The Google Container Registry or Artifact Registry URI where + the Model container image will be copied to. Accepted forms: + + - Google Container Registry path. For example: + ``gcr.io/projectId/imageName:tag``. + + - Artifact Registry path. For example: + ``us-central1-docker.pkg.dev/projectId/repoName/imageName:tag``. + + This field should only be set when, in [Model.supported_export_formats], + the value for the key given in `export_format_id` contains ``IMAGE``. + sync (bool): + Whether to execute this export synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + output_info (Dict[str, str]): + Details of the completed export with output destination paths to + the artifacts or container image. + Raises: + ValueError if model does not support exporting. + + ValueError if invalid arguments or export formats are provided. + """ + + # Model does not support exporting + if not self.supported_export_formats: + raise ValueError(f"The model `{self.resource_name}` is not exportable.") + + # No destination provided + if not any((artifact_destination, image_destination)): + raise ValueError( + "Please provide an `artifact_destination` or `image_destination`." + ) + + export_format_id = export_format_id.lower() + + # Unsupported export type + if export_format_id not in self.supported_export_formats: + raise ValueError( + f"'{export_format_id}' is not a supported export format for this model. " + f"Choose one of the following: {self.supported_export_formats}" + ) + + content_types = gca_model_compat.Model.ExportFormat.ExportableContent + supported_content_types = self.supported_export_formats[export_format_id] + + if ( + artifact_destination + and content_types.ARTIFACT not in supported_content_types + ): + raise ValueError( + "This model can not be exported as an artifact in '{export_format_id}' format. " + "Try exporting as a container image by passing the `image_destination` argument." + ) + + if image_destination and content_types.IMAGE not in supported_content_types: + raise ValueError( + "This model can not be exported as a container image in '{export_format_id}' format. " + "Try exporting the model artifacts by passing a `artifact_destination` argument." + ) + + # Construct request payload + output_config = gca_model_service_compat.ExportModelRequest.OutputConfig( + export_format_id=export_format_id + ) + + if artifact_destination: + output_config.artifact_destination = gca_io_compat.GcsDestination( + output_uri_prefix=artifact_destination + ) + + if image_destination: + output_config.image_destination = gca_io_compat.ContainerRegistryDestination( + output_uri=image_destination + ) + + _LOGGER.log_action_start_against_resource("Exporting", "model", self) + + operation_future = self.api_client.export_model( + name=self.resource_name, output_config=output_config + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Export", "model", self.__class__, operation_future + ) + + # Block before returning + self._wait_on_export(operation_future=operation_future, sync=sync) + + _LOGGER.log_action_completed_against_resource("model", "exported", self) + + return json_format.MessageToDict(operation_future.metadata.output_info._pb) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 220a34637e..441f91ca39 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -143,7 +143,7 @@ def __init__( @classmethod @abc.abstractmethod def _supported_training_schemas(cls) -> Tuple[str]: - """List of supported schemas for this training job""" + """List of supported schemas for this training job.""" pass @@ -211,7 +211,10 @@ def _model_upload_fail_string(self) -> str: @abc.abstractmethod def run(self) -> Optional[models.Model]: - """Runs the training job. Should call _run_job internally""" + """Runs the training job. + + Should call _run_job internally + """ pass @staticmethod @@ -530,7 +533,8 @@ def _run_job( return model def _is_waiting_to_run(self) -> bool: - """Returns True if the Job is pending on upstream tasks False otherwise.""" + """Returns True if the Job is pending on upstream tasks False + otherwise.""" self._raise_future_exception() if self._latest_future: _LOGGER.info( @@ -563,7 +567,7 @@ def get_model(self, sync=True) -> models.Model: model: AI Platform Model produced by this training Raises: - RuntimeError if training failed or if a model was not produced by this training. + RuntimeError: If training failed or if a model was not produced by this training. """ self._assert_has_run() @@ -586,7 +590,7 @@ def _force_get_model(self, sync: bool = True) -> models.Model: model: AI Platform Model produced by this training Raises: - RuntimeError if training failed or if a model was not produced by this training. + RuntimeError: If training failed or if a model was not produced by this training. """ model = self._get_model() @@ -603,7 +607,7 @@ def _get_model(self) -> Optional[models.Model]: Model. None otherwise. Raises: - RuntimeError if Training failed. + RuntimeError: If Training failed. """ self._block_until_complete() @@ -662,19 +666,24 @@ def _raise_failure(self): """Helper method to raise failure if TrainingPipeline fails. Raises: - RuntimeError: If training failed.""" + RuntimeError: If training failed. + """ if self._gca_resource.error.code != code_pb2.OK: raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error) @property def has_failed(self) -> bool: - """Returns True if training has failed. False otherwise.""" + """Returns True if training has failed. + + False otherwise. + """ self._assert_has_run() return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED def _dashboard_uri(self) -> str: - """Helper method to compose the dashboard uri where training can be viewed.""" + """Helper method to compose the dashboard uri where training can be + viewed.""" fields = utils.extract_fields_from_resource_name(self.resource_name) url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}" return url @@ -762,7 +771,7 @@ def cancel(self) -> None: becomes a job with state set to `CANCELLED`. Raises: - RuntimeError if this TrainingJob has not started running. + RuntimeError: If this TrainingJob has not started running. """ if not self._has_run: raise RuntimeError( @@ -838,10 +847,10 @@ def _timestamped_copy_to_gcs( def _get_python_executable() -> str: """Returns Python executable. - Raises: - EnvironmentError if Python executable is not found. Returns: Python executable to use for setuptools packaging. + Raises: + EnvironmentError: If Python executable is not found. """ python_executable = sys.executable @@ -852,7 +861,8 @@ def _get_python_executable() -> str: class _TrainingScriptPythonPackager: - """Converts a Python script into Python package suitable for aiplatform training. + """Converts a Python script into Python package suitable for aiplatform + training. Copies the script to specified location. @@ -879,7 +889,6 @@ class _TrainingScriptPythonPackager: The package after installed can be executed as: python -m aiplatform_custom_trainer_script.task - """ _TRAINER_FOLDER = "trainer" @@ -917,14 +926,15 @@ def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = Non self.requirements = requirements or [] def make_package(self, package_directory: str) -> str: - """Converts script into a Python package suitable for python module execution. + """Converts script into a Python package suitable for python module + execution. Args: package_directory (str): Directory to build package in. Returns: source_distribution_path (str): Path to built package. Raises: - RunTimeError if package creation fails. + RunTimeError: If package creation fails. """ # The root folder to builder the package in package_path = pathlib.Path(package_directory) @@ -1126,7 +1136,6 @@ class _DistributedTrainingSpec(NamedTuple): accelerator_type='NVIDIA_TESLA_K80' ) ) - """ chief_spec: _MachineSpec = _MachineSpec() @@ -1138,7 +1147,8 @@ class _DistributedTrainingSpec(NamedTuple): def pool_specs( self, ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]: - """Return each pools spec in correct order for AI Platform as a list of dicts. + """Return each pools spec in correct order for AI Platform as a list of + dicts. Also removes specs if they are empty but leaves specs in if there unusual specifications to not break the ordering in AI Platform Training. @@ -1186,7 +1196,7 @@ def chief_worker_pool( accelerator_type (str): Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_T4 accelerator_count (int): The number of accelerators to attach to a worker replica. @@ -1215,8 +1225,7 @@ def chief_worker_pool( class _CustomTrainingJob(_TrainingJob): - """ABC for Custom Training Pipelines.. - """ + """ABC for Custom Training Pipelines..""" _supported_training_schemas = (schema.training_job.definition.custom_task,) @@ -1448,13 +1457,16 @@ def _prepare_and_validate_run( accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, ) -> Tuple[_DistributedTrainingSpec, Optional[gca_model.Model]]: - """Create worker pool specs and managed model as well validating the run. + """Create worker pool specs and managed model as well validating the + run. Args: model_display_name (str): If the script produces a managed AI Platform Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. replica_count (int): The number of worker replicas. If replica count = 1 then one chief replica will be provisioned. If replica_count > 1 the remainder will be @@ -1464,16 +1476,15 @@ def _prepare_and_validate_run( accelerator_type (str): Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_T4 accelerator_count (int): The number of accelerators to attach to a worker replica. Returns: Worker pools specs and managed model for run. Raises: - RuntimeError if Training job has already been run or model_display_name was - provided but required arguments were not provided in constructor. - + RuntimeError: If Training job has already been run or model_display_name was + provided but required arguments were not provided in constructor. """ if self._is_waiting_to_run(): @@ -1491,6 +1502,9 @@ def _prepare_and_validate_run( """ ) + if self._managed_model.container_spec.image_uri: + model_display_name = model_display_name or self._display_name + "-model" + # validates args and will raise worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool( replica_count=replica_count, @@ -1512,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir( self, worker_pool_specs: _DistributedTrainingSpec, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1521,6 +1536,9 @@ def _prepare_training_task_inputs_and_output_dir( base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. Returns: Training task inputs and Output directory for custom job. """ @@ -1537,6 +1555,9 @@ def _prepare_training_task_inputs_and_output_dir( "baseOutputDirectory": {"output_uri_prefix": base_output_dir}, } + if service_account: + training_task_inputs["serviceAccount"] = service_account + return training_task_inputs, base_output_dir @property @@ -1555,8 +1576,8 @@ def _model_upload_fail_string(self) -> str: class CustomTrainingJob(_CustomTrainingJob): """Class to launch a Custom Training Job in AI Platform using a script. - Takes a training implementation as a python script and executes that script - in Cloud AI Platform Training. + Takes a training implementation as a python script and executes that + script in Cloud AI Platform Training. """ def __init__( @@ -1782,6 +1803,7 @@ def run( annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, replica_count: int = 0, @@ -1854,9 +1876,14 @@ def run( If the script produces a managed AI Platform Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -1883,7 +1910,7 @@ def run( accelerator_type (str): Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_T4 accelerator_count (int): The number of accelerators to attach to a worker replica. training_fraction_split (float): @@ -1935,6 +1962,7 @@ def run( managed_model=managed_model, args=args, base_output_dir=base_output_dir, + service_account=service_account, bigquery_destination=bigquery_destination, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, @@ -1960,6 +1988,7 @@ def _run( managed_model: Optional[gca_model.Model] = None, args: Optional[List[Union[str, float, int]]] = None, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, bigquery_destination: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -1993,6 +2022,9 @@ def _run( base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -2056,7 +2088,7 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir + worker_pool_specs, base_output_dir, service_account ) model = self._run_job( @@ -2077,7 +2109,8 @@ def _run( class CustomContainerTrainingJob(_CustomTrainingJob): - """Class to launch a Custom Training Job in AI Platform using a Container.""" + """Class to launch a Custom Training Job in AI Platform using a + Container.""" def __init__( self, @@ -2299,6 +2332,7 @@ def run( annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, replica_count: int = 0, @@ -2327,14 +2361,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset ( - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ): + dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset]): AI Platform to fit this training against. Custom training script should retrieve datasets through passed in environment variables uris: @@ -2371,9 +2398,14 @@ def run( If the script produces a managed AI Platform Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -2400,7 +2432,7 @@ def run( accelerator_type (str): Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_T4 accelerator_count (int): The number of accelerators to attach to a worker replica. training_fraction_split (float): @@ -2432,7 +2464,7 @@ def run( produce an AI Platform Model. Raises: - RuntimeError if Training job has already been run, staging_bucket has not + RuntimeError: If Training job has already been run, staging_bucket has not been set, or model_display_name was provided but required arguments were not provided in constructor. """ @@ -2451,6 +2483,7 @@ def run( managed_model=managed_model, args=args, base_output_dir=base_output_dir, + service_account=service_account, bigquery_destination=bigquery_destination, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, @@ -2475,6 +2508,7 @@ def _run( managed_model: Optional[gca_model.Model] = None, args: Optional[List[Union[str, float, int]]] = None, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, bigquery_destination: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2505,6 +2539,9 @@ def _run( base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. bigquery_destination (str): The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -2561,7 +2598,7 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir + worker_pool_specs, base_output_dir, service_account ) model = self._run_job( @@ -2800,7 +2837,7 @@ def run( produce an AI Platform Model. Raises: - RuntimeError if Training job has already been run or is waiting to run. + RuntimeError: If Training job has already been run or is waiting to run. """ if self._is_waiting_to_run(): @@ -3330,10 +3367,11 @@ def _model_upload_fail_string(self) -> str: class CustomPythonPackageTrainingJob(_CustomTrainingJob): - """Class to launch a Custom Training Job in AI Platform using a Python Package. + """Class to launch a Custom Training Job in AI Platform using a Python + Package. - Takes a training implementation as a python package and executes that package - in Cloud AI Platform Training. + Takes a training implementation as a python package and executes + that package in Cloud AI Platform Training. """ def __init__( @@ -3564,6 +3602,7 @@ def run( annotation_schema_uri: Optional[str] = None, model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, replica_count: int = 0, @@ -3592,14 +3631,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset ( - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ): + dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]): AI Platform to fit this training against. Custom training script should retrieve datasets through passed in environement variables uris: @@ -3636,9 +3668,14 @@ def run( If the script produces a managed AI Platform Model. The display name of the Model. The name can be up to 128 characters long and can be consist of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -3665,7 +3702,7 @@ def run( accelerator_type (str): Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, - NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + NVIDIA_TESLA_T4 accelerator_count (int): The number of accelerators to attach to a worker replica. training_fraction_split (float): @@ -3711,6 +3748,7 @@ def run( managed_model=managed_model, args=args, base_output_dir=base_output_dir, + service_account=service_account, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, @@ -3735,6 +3773,7 @@ def _run( managed_model: Optional[gca_model.Model] = None, args: Optional[List[Union[str, float, int]]] = None, base_output_dir: Optional[str] = None, + service_account: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -3766,6 +3805,9 @@ def _run( base_output_dir (str): GCS output directory of job. If not provided a timestamped directory in the staging directory will be used. + service_account (str): + Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. training_fraction_split (float): The fraction of the input data that is to be used to train the Model. @@ -3808,7 +3850,7 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir + worker_pool_specs, base_output_dir, service_account ) model = self._run_job( @@ -4263,7 +4305,7 @@ def run( model: The trained AI Platform Model resource. Raises: - RuntimeError if Training job has already been run or is waiting to run. + RuntimeError: If Training job has already been run or is waiting to run. """ if self._is_waiting_to_run(): diff --git a/google/cloud/aiplatform/training_utils.py b/google/cloud/aiplatform/training_utils.py index a93ecaa1ce..fea60c5005 100644 --- a/google/cloud/aiplatform/training_utils.py +++ b/google/cloud/aiplatform/training_utils.py @@ -22,7 +22,7 @@ class EnvironmentVariables: - """Passes on OS' environment variables""" + """Passes on OS' environment variables.""" @property def training_data_uri(self) -> Optional[str]: diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index c90bc4c3f0..02f9fd8388 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -79,7 +79,8 @@ def _match_to_fields(match: Match) -> Optional[Fields]: - """Normalize RegEx groups from resource name pattern Match to class Fields""" + """Normalize RegEx groups from resource name pattern Match to class + Fields.""" if not match: return None @@ -92,15 +93,15 @@ def _match_to_fields(match: Match) -> Optional[Fields]: def validate_id(resource_id: str) -> bool: - """Validate int64 resource ID number""" + """Validate int64 resource ID number.""" return bool(RESOURCE_ID_PATTERN.match(resource_id)) def extract_fields_from_resource_name( resource_name: str, resource_noun: Optional[str] = None ) -> Optional[Fields]: - """Validates and returns extracted fields from a fully-qualified resource name. - Returns None if name is invalid. + """Validates and returns extracted fields from a fully-qualified resource + name. Returns None if name is invalid. Args: resource_name (str): @@ -135,8 +136,7 @@ def full_resource_name( project: Optional[str] = None, location: Optional[str] = None, ) -> str: - """ - Returns fully qualified resource name. + """Returns fully qualified resource name. Args: resource_name (str): @@ -223,7 +223,7 @@ def validate_project(project: str) -> bool: # TODO(b/172932277) verify display name only contains utf-8 chars def validate_display_name(display_name: str): - """Verify display name is at most 128 chars + """Verify display name is at most 128 chars. Args: display_name: display name to verify @@ -259,7 +259,8 @@ def validate_region(region: str) -> bool: def validate_accelerator_type(accelerator_type: str) -> bool: - """Validates user provided accelerator_type string for training and prediction + """Validates user provided accelerator_type string for training and + prediction. Args: accelerator_type (str): @@ -313,7 +314,8 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona class ClientWithOverride: class WrappedClient: - """Wrapper class for client that creates client at API invocation time.""" + """Wrapper class for client that creates client at API invocation + time.""" def __init__( self, @@ -324,14 +326,15 @@ def __init__( ): """Stores parameters needed to instantiate client. - client_class (AiPlatformServiceClient): - Required. Class of the client to use. - client_options (client_options.ClientOptions): - Required. Client options to pass to client. - client_info (gapic_v1.client_info.ClientInfo): - Required. Client info to pass to client. - credentials (auth_credentials.credentials): - Optional. Client credentials to pass to client. + Args: + client_class (AiPlatformServiceClient): + Required. Class of the client to use. + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. """ self._client_class = client_class @@ -373,12 +376,13 @@ def __init__( ): """Stores parameters needed to instantiate client. - client_options (client_options.ClientOptions): - Required. Client options to pass to client. - client_info (gapic_v1.client_info.ClientInfo): - Required. Client info to pass to client. - credentials (auth_credentials.credentials): - Optional. Client credentials to pass to client. + Args: + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. """ self._clients = { @@ -423,7 +427,7 @@ class EndpointClientWithOverride(ClientWithOverride): ) -class JobpointClientWithOverride(ClientWithOverride): +class JobClientWithOverride(ClientWithOverride): _is_temporary = True _default_version = compat.DEFAULT_VERSION _version_map = ( @@ -471,7 +475,7 @@ class MetadataClientWithOverride(ClientWithOverride): "AiPlatformServiceClientWithOverride", DatasetClientWithOverride, EndpointClientWithOverride, - JobpointClientWithOverride, + JobClientWithOverride, ModelClientWithOverride, PipelineClientWithOverride, PredictionClientWithOverride, diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_metadata.json b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_metadata.json new file mode 100644 index 0000000000..0ae909d6ea --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1.schema.predict.instance_v1", + "protoPackage": "google.cloud.aiplatform.v1.schema.predict.instance", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_metadata.json b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_metadata.json new file mode 100644 index 0000000000..edfffb441b --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1.schema.predict.params_v1", + "protoPackage": "google.cloud.aiplatform.v1.schema.predict.params", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_metadata.json b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_metadata.json new file mode 100644 index 0000000000..ba1d67a00c --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1.schema.predict.prediction_v1", + "protoPackage": "google.cloud.aiplatform.v1.schema.predict.prediction", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_metadata.json b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_metadata.json new file mode 100644 index 0000000000..620ff75f05 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1.schema.trainingjob.definition_v1", + "protoPackage": "google.cloud.aiplatform.v1.schema.trainingjob.definition", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_metadata.json b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_metadata.json new file mode 100644 index 0000000000..38379e8208 --- /dev/null +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1", + "protoPackage": "google.cloud.aiplatform.v1beta1.schema.predict.instance", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_metadata.json b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_metadata.json new file mode 100644 index 0000000000..6b925dd9dc --- /dev/null +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1", + "protoPackage": "google.cloud.aiplatform.v1beta1.schema.predict.params", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_metadata.json b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_metadata.json new file mode 100644 index 0000000000..99d3dc6402 --- /dev/null +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1", + "protoPackage": "google.cloud.aiplatform.v1beta1.schema.predict.prediction", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_metadata.json b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_metadata.json new file mode 100644 index 0000000000..6de794c90a --- /dev/null +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_metadata.json @@ -0,0 +1,7 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1", + "protoPackage": "google.cloud.aiplatform.v1beta1.schema.trainingjob.definition", + "schema": "1.0" +} diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json new file mode 100644 index 0000000000..0abed0fd70 --- /dev/null +++ b/google/cloud/aiplatform_v1/gapic_metadata.json @@ -0,0 +1,721 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform_v1", + "protoPackage": "google.cloud.aiplatform.v1", + "schema": "1.0", + "services": { + "DatasetService": { + "clients": { + "grpc": { + "libraryClient": "DatasetServiceClient", + "rpcs": { + "CreateDataset": { + "methods": [ + "create_dataset" + ] + }, + "DeleteDataset": { + "methods": [ + "delete_dataset" + ] + }, + "ExportData": { + "methods": [ + "export_data" + ] + }, + "GetAnnotationSpec": { + "methods": [ + "get_annotation_spec" + ] + }, + "GetDataset": { + "methods": [ + "get_dataset" + ] + }, + "ImportData": { + "methods": [ + "import_data" + ] + }, + "ListAnnotations": { + "methods": [ + "list_annotations" + ] + }, + "ListDataItems": { + "methods": [ + "list_data_items" + ] + }, + "ListDatasets": { + "methods": [ + "list_datasets" + ] + }, + "UpdateDataset": { + "methods": [ + "update_dataset" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DatasetServiceAsyncClient", + "rpcs": { + "CreateDataset": { + "methods": [ + "create_dataset" + ] + }, + "DeleteDataset": { + "methods": [ + "delete_dataset" + ] + }, + "ExportData": { + "methods": [ + "export_data" + ] + }, + "GetAnnotationSpec": { + "methods": [ + "get_annotation_spec" + ] + }, + "GetDataset": { + "methods": [ + "get_dataset" + ] + }, + "ImportData": { + "methods": [ + "import_data" + ] + }, + "ListAnnotations": { + "methods": [ + "list_annotations" + ] + }, + "ListDataItems": { + "methods": [ + "list_data_items" + ] + }, + "ListDatasets": { + "methods": [ + "list_datasets" + ] + }, + "UpdateDataset": { + "methods": [ + "update_dataset" + ] + } + } + } + } + }, + "EndpointService": { + "clients": { + "grpc": { + "libraryClient": "EndpointServiceClient", + "rpcs": { + "CreateEndpoint": { + "methods": [ + "create_endpoint" + ] + }, + "DeleteEndpoint": { + "methods": [ + "delete_endpoint" + ] + }, + "DeployModel": { + "methods": [ + "deploy_model" + ] + }, + "GetEndpoint": { + "methods": [ + "get_endpoint" + ] + }, + "ListEndpoints": { + "methods": [ + "list_endpoints" + ] + }, + "UndeployModel": { + "methods": [ + "undeploy_model" + ] + }, + "UpdateEndpoint": { + "methods": [ + "update_endpoint" + ] + } + } + }, + "grpc-async": { + "libraryClient": "EndpointServiceAsyncClient", + "rpcs": { + "CreateEndpoint": { + "methods": [ + "create_endpoint" + ] + }, + "DeleteEndpoint": { + "methods": [ + "delete_endpoint" + ] + }, + "DeployModel": { + "methods": [ + "deploy_model" + ] + }, + "GetEndpoint": { + "methods": [ + "get_endpoint" + ] + }, + "ListEndpoints": { + "methods": [ + "list_endpoints" + ] + }, + "UndeployModel": { + "methods": [ + "undeploy_model" + ] + }, + "UpdateEndpoint": { + "methods": [ + "update_endpoint" + ] + } + } + } + } + }, + "JobService": { + "clients": { + "grpc": { + "libraryClient": "JobServiceClient", + "rpcs": { + "CancelBatchPredictionJob": { + "methods": [ + "cancel_batch_prediction_job" + ] + }, + "CancelCustomJob": { + "methods": [ + "cancel_custom_job" + ] + }, + "CancelDataLabelingJob": { + "methods": [ + "cancel_data_labeling_job" + ] + }, + "CancelHyperparameterTuningJob": { + "methods": [ + "cancel_hyperparameter_tuning_job" + ] + }, + "CreateBatchPredictionJob": { + "methods": [ + "create_batch_prediction_job" + ] + }, + "CreateCustomJob": { + "methods": [ + "create_custom_job" + ] + }, + "CreateDataLabelingJob": { + "methods": [ + "create_data_labeling_job" + ] + }, + "CreateHyperparameterTuningJob": { + "methods": [ + "create_hyperparameter_tuning_job" + ] + }, + "DeleteBatchPredictionJob": { + "methods": [ + "delete_batch_prediction_job" + ] + }, + "DeleteCustomJob": { + "methods": [ + "delete_custom_job" + ] + }, + "DeleteDataLabelingJob": { + "methods": [ + "delete_data_labeling_job" + ] + }, + "DeleteHyperparameterTuningJob": { + "methods": [ + "delete_hyperparameter_tuning_job" + ] + }, + "GetBatchPredictionJob": { + "methods": [ + "get_batch_prediction_job" + ] + }, + "GetCustomJob": { + "methods": [ + "get_custom_job" + ] + }, + "GetDataLabelingJob": { + "methods": [ + "get_data_labeling_job" + ] + }, + "GetHyperparameterTuningJob": { + "methods": [ + "get_hyperparameter_tuning_job" + ] + }, + "ListBatchPredictionJobs": { + "methods": [ + "list_batch_prediction_jobs" + ] + }, + "ListCustomJobs": { + "methods": [ + "list_custom_jobs" + ] + }, + "ListDataLabelingJobs": { + "methods": [ + "list_data_labeling_jobs" + ] + }, + "ListHyperparameterTuningJobs": { + "methods": [ + "list_hyperparameter_tuning_jobs" + ] + } + } + }, + "grpc-async": { + "libraryClient": "JobServiceAsyncClient", + "rpcs": { + "CancelBatchPredictionJob": { + "methods": [ + "cancel_batch_prediction_job" + ] + }, + "CancelCustomJob": { + "methods": [ + "cancel_custom_job" + ] + }, + "CancelDataLabelingJob": { + "methods": [ + "cancel_data_labeling_job" + ] + }, + "CancelHyperparameterTuningJob": { + "methods": [ + "cancel_hyperparameter_tuning_job" + ] + }, + "CreateBatchPredictionJob": { + "methods": [ + "create_batch_prediction_job" + ] + }, + "CreateCustomJob": { + "methods": [ + "create_custom_job" + ] + }, + "CreateDataLabelingJob": { + "methods": [ + "create_data_labeling_job" + ] + }, + "CreateHyperparameterTuningJob": { + "methods": [ + "create_hyperparameter_tuning_job" + ] + }, + "DeleteBatchPredictionJob": { + "methods": [ + "delete_batch_prediction_job" + ] + }, + "DeleteCustomJob": { + "methods": [ + "delete_custom_job" + ] + }, + "DeleteDataLabelingJob": { + "methods": [ + "delete_data_labeling_job" + ] + }, + "DeleteHyperparameterTuningJob": { + "methods": [ + "delete_hyperparameter_tuning_job" + ] + }, + "GetBatchPredictionJob": { + "methods": [ + "get_batch_prediction_job" + ] + }, + "GetCustomJob": { + "methods": [ + "get_custom_job" + ] + }, + "GetDataLabelingJob": { + "methods": [ + "get_data_labeling_job" + ] + }, + "GetHyperparameterTuningJob": { + "methods": [ + "get_hyperparameter_tuning_job" + ] + }, + "ListBatchPredictionJobs": { + "methods": [ + "list_batch_prediction_jobs" + ] + }, + "ListCustomJobs": { + "methods": [ + "list_custom_jobs" + ] + }, + "ListDataLabelingJobs": { + "methods": [ + "list_data_labeling_jobs" + ] + }, + "ListHyperparameterTuningJobs": { + "methods": [ + "list_hyperparameter_tuning_jobs" + ] + } + } + } + } + }, + "MigrationService": { + "clients": { + "grpc": { + "libraryClient": "MigrationServiceClient", + "rpcs": { + "BatchMigrateResources": { + "methods": [ + "batch_migrate_resources" + ] + }, + "SearchMigratableResources": { + "methods": [ + "search_migratable_resources" + ] + } + } + }, + "grpc-async": { + "libraryClient": "MigrationServiceAsyncClient", + "rpcs": { + "BatchMigrateResources": { + "methods": [ + "batch_migrate_resources" + ] + }, + "SearchMigratableResources": { + "methods": [ + "search_migratable_resources" + ] + } + } + } + } + }, + "ModelService": { + "clients": { + "grpc": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "DeleteModel": { + "methods": [ + "delete_model" + ] + }, + "ExportModel": { + "methods": [ + "export_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetModelEvaluation": { + "methods": [ + "get_model_evaluation" + ] + }, + "GetModelEvaluationSlice": { + "methods": [ + "get_model_evaluation_slice" + ] + }, + "ListModelEvaluationSlices": { + "methods": [ + "list_model_evaluation_slices" + ] + }, + "ListModelEvaluations": { + "methods": [ + "list_model_evaluations" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "UpdateModel": { + "methods": [ + "update_model" + ] + }, + "UploadModel": { + "methods": [ + "upload_model" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ModelServiceAsyncClient", + "rpcs": { + "DeleteModel": { + "methods": [ + "delete_model" + ] + }, + "ExportModel": { + "methods": [ + "export_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetModelEvaluation": { + "methods": [ + "get_model_evaluation" + ] + }, + "GetModelEvaluationSlice": { + "methods": [ + "get_model_evaluation_slice" + ] + }, + "ListModelEvaluationSlices": { + "methods": [ + "list_model_evaluation_slices" + ] + }, + "ListModelEvaluations": { + "methods": [ + "list_model_evaluations" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "UpdateModel": { + "methods": [ + "update_model" + ] + }, + "UploadModel": { + "methods": [ + "upload_model" + ] + } + } + } + } + }, + "PipelineService": { + "clients": { + "grpc": { + "libraryClient": "PipelineServiceClient", + "rpcs": { + "CancelTrainingPipeline": { + "methods": [ + "cancel_training_pipeline" + ] + }, + "CreateTrainingPipeline": { + "methods": [ + "create_training_pipeline" + ] + }, + "DeleteTrainingPipeline": { + "methods": [ + "delete_training_pipeline" + ] + }, + "GetTrainingPipeline": { + "methods": [ + "get_training_pipeline" + ] + }, + "ListTrainingPipelines": { + "methods": [ + "list_training_pipelines" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PipelineServiceAsyncClient", + "rpcs": { + "CancelTrainingPipeline": { + "methods": [ + "cancel_training_pipeline" + ] + }, + "CreateTrainingPipeline": { + "methods": [ + "create_training_pipeline" + ] + }, + "DeleteTrainingPipeline": { + "methods": [ + "delete_training_pipeline" + ] + }, + "GetTrainingPipeline": { + "methods": [ + "get_training_pipeline" + ] + }, + "ListTrainingPipelines": { + "methods": [ + "list_training_pipelines" + ] + } + } + } + } + }, + "PredictionService": { + "clients": { + "grpc": { + "libraryClient": "PredictionServiceClient", + "rpcs": { + "Predict": { + "methods": [ + "predict" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PredictionServiceAsyncClient", + "rpcs": { + "Predict": { + "methods": [ + "predict" + ] + } + } + } + } + }, + "SpecialistPoolService": { + "clients": { + "grpc": { + "libraryClient": "SpecialistPoolServiceClient", + "rpcs": { + "CreateSpecialistPool": { + "methods": [ + "create_specialist_pool" + ] + }, + "DeleteSpecialistPool": { + "methods": [ + "delete_specialist_pool" + ] + }, + "GetSpecialistPool": { + "methods": [ + "get_specialist_pool" + ] + }, + "ListSpecialistPools": { + "methods": [ + "list_specialist_pools" + ] + }, + "UpdateSpecialistPool": { + "methods": [ + "update_specialist_pool" + ] + } + } + }, + "grpc-async": { + "libraryClient": "SpecialistPoolServiceAsyncClient", + "rpcs": { + "CreateSpecialistPool": { + "methods": [ + "create_specialist_pool" + ] + }, + "DeleteSpecialistPool": { + "methods": [ + "delete_specialist_pool" + ] + }, + "GetSpecialistPool": { + "methods": [ + "get_specialist_pool" + ] + }, + "ListSpecialistPools": { + "methods": [ + "list_specialist_pools" + ] + }, + "UpdateSpecialistPool": { + "methods": [ + "update_specialist_pool" + ] + } + } + } + } + } + } +} diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 0d6e0fdbd6..042e3402d1 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -180,32 +180,32 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 3f605a0fcb..094b82e45c 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -30,6 +30,7 @@ from .services.pipeline_service import PipelineServiceClient from .services.prediction_service import PredictionServiceClient from .services.specialist_pool_service import SpecialistPoolServiceClient +from .services.tensorboard_service import TensorboardServiceClient from .services.vizier_service import VizierServiceClient from .types.accelerator_type import AcceleratorType from .types.annotation import Annotation @@ -115,7 +116,6 @@ from .types.featurestore_online_service import FeatureValueList from .types.featurestore_online_service import ReadFeatureValuesRequest from .types.featurestore_online_service import ReadFeatureValuesResponse -from .types.featurestore_online_service import ReadSetting from .types.featurestore_online_service import StreamingReadFeatureValuesRequest from .types.featurestore_service import BatchCreateFeaturesOperationMetadata from .types.featurestore_service import BatchCreateFeaturesRequest @@ -133,6 +133,9 @@ from .types.featurestore_service import DeleteFeatureRequest from .types.featurestore_service import DeleteFeaturestoreRequest from .types.featurestore_service import DestinationFeatureSetting +from .types.featurestore_service import ExportFeatureValuesOperationMetadata +from .types.featurestore_service import ExportFeatureValuesRequest +from .types.featurestore_service import ExportFeatureValuesResponse from .types.featurestore_service import FeatureValueDestination from .types.featurestore_service import GetEntityTypeRequest from .types.featurestore_service import GetFeatureRequest @@ -323,10 +326,20 @@ from .types.model_service import UploadModelResponse from .types.operation import DeleteOperationMetadata from .types.operation import GenericOperationMetadata +from .types.pipeline_job import PipelineJob +from .types.pipeline_job import PipelineJobDetail +from .types.pipeline_job import PipelineTaskDetail +from .types.pipeline_job import PipelineTaskExecutorDetail +from .types.pipeline_service import CancelPipelineJobRequest from .types.pipeline_service import CancelTrainingPipelineRequest +from .types.pipeline_service import CreatePipelineJobRequest from .types.pipeline_service import CreateTrainingPipelineRequest +from .types.pipeline_service import DeletePipelineJobRequest from .types.pipeline_service import DeleteTrainingPipelineRequest +from .types.pipeline_service import GetPipelineJobRequest from .types.pipeline_service import GetTrainingPipelineRequest +from .types.pipeline_service import ListPipelineJobsRequest +from .types.pipeline_service import ListPipelineJobsResponse from .types.pipeline_service import ListTrainingPipelinesRequest from .types.pipeline_service import ListTrainingPipelinesResponse from .types.pipeline_state import PipelineState @@ -347,6 +360,50 @@ from .types.study import Study from .types.study import StudySpec from .types.study import Trial +from .types.tensorboard import Tensorboard +from .types.tensorboard_data import Scalar +from .types.tensorboard_data import TensorboardBlob +from .types.tensorboard_data import TensorboardBlobSequence +from .types.tensorboard_data import TensorboardTensor +from .types.tensorboard_data import TimeSeriesData +from .types.tensorboard_data import TimeSeriesDataPoint +from .types.tensorboard_experiment import TensorboardExperiment +from .types.tensorboard_run import TensorboardRun +from .types.tensorboard_service import CreateTensorboardExperimentRequest +from .types.tensorboard_service import CreateTensorboardOperationMetadata +from .types.tensorboard_service import CreateTensorboardRequest +from .types.tensorboard_service import CreateTensorboardRunRequest +from .types.tensorboard_service import CreateTensorboardTimeSeriesRequest +from .types.tensorboard_service import DeleteTensorboardExperimentRequest +from .types.tensorboard_service import DeleteTensorboardRequest +from .types.tensorboard_service import DeleteTensorboardRunRequest +from .types.tensorboard_service import DeleteTensorboardTimeSeriesRequest +from .types.tensorboard_service import ExportTensorboardTimeSeriesDataRequest +from .types.tensorboard_service import ExportTensorboardTimeSeriesDataResponse +from .types.tensorboard_service import GetTensorboardExperimentRequest +from .types.tensorboard_service import GetTensorboardRequest +from .types.tensorboard_service import GetTensorboardRunRequest +from .types.tensorboard_service import GetTensorboardTimeSeriesRequest +from .types.tensorboard_service import ListTensorboardExperimentsRequest +from .types.tensorboard_service import ListTensorboardExperimentsResponse +from .types.tensorboard_service import ListTensorboardRunsRequest +from .types.tensorboard_service import ListTensorboardRunsResponse +from .types.tensorboard_service import ListTensorboardTimeSeriesRequest +from .types.tensorboard_service import ListTensorboardTimeSeriesResponse +from .types.tensorboard_service import ListTensorboardsRequest +from .types.tensorboard_service import ListTensorboardsResponse +from .types.tensorboard_service import ReadTensorboardBlobDataRequest +from .types.tensorboard_service import ReadTensorboardBlobDataResponse +from .types.tensorboard_service import ReadTensorboardTimeSeriesDataRequest +from .types.tensorboard_service import ReadTensorboardTimeSeriesDataResponse +from .types.tensorboard_service import UpdateTensorboardExperimentRequest +from .types.tensorboard_service import UpdateTensorboardOperationMetadata +from .types.tensorboard_service import UpdateTensorboardRequest +from .types.tensorboard_service import UpdateTensorboardRunRequest +from .types.tensorboard_service import UpdateTensorboardTimeSeriesRequest +from .types.tensorboard_service import WriteTensorboardRunDataRequest +from .types.tensorboard_service import WriteTensorboardRunDataResponse +from .types.tensorboard_time_series import TensorboardTimeSeries from .types.training_pipeline import FilterSplit from .types.training_pipeline import FractionSplit from .types.training_pipeline import InputDataConfig @@ -358,6 +415,7 @@ from .types.types import Int64Array from .types.types import StringArray from .types.user_action_reference import UserActionReference +from .types.value import Value from .types.vizier_service import AddTrialMeasurementRequest from .types.vizier_service import CheckTrialEarlyStoppingStateMetatdata from .types.vizier_service import CheckTrialEarlyStoppingStateRequest @@ -417,6 +475,7 @@ "CancelCustomJobRequest", "CancelDataLabelingJobRequest", "CancelHyperparameterTuningJobRequest", + "CancelPipelineJobRequest", "CancelTrainingPipelineRequest", "CheckTrialEarlyStoppingStateMetatdata", "CheckTrialEarlyStoppingStateRequest", @@ -451,9 +510,15 @@ "CreateMetadataStoreOperationMetadata", "CreateMetadataStoreRequest", "CreateModelDeploymentMonitoringJobRequest", + "CreatePipelineJobRequest", "CreateSpecialistPoolOperationMetadata", "CreateSpecialistPoolRequest", "CreateStudyRequest", + "CreateTensorboardExperimentRequest", + "CreateTensorboardOperationMetadata", + "CreateTensorboardRequest", + "CreateTensorboardRunRequest", + "CreateTensorboardTimeSeriesRequest", "CreateTrainingPipelineRequest", "CreateTrialRequest", "CsvDestination", @@ -482,8 +547,13 @@ "DeleteModelDeploymentMonitoringJobRequest", "DeleteModelRequest", "DeleteOperationMetadata", + "DeletePipelineJobRequest", "DeleteSpecialistPoolRequest", "DeleteStudyRequest", + "DeleteTensorboardExperimentRequest", + "DeleteTensorboardRequest", + "DeleteTensorboardRunRequest", + "DeleteTensorboardTimeSeriesRequest", "DeleteTrainingPipelineRequest", "DeleteTrialRequest", "DeployIndexOperationMetadata", @@ -519,9 +589,14 @@ "ExportDataOperationMetadata", "ExportDataRequest", "ExportDataResponse", + "ExportFeatureValuesOperationMetadata", + "ExportFeatureValuesRequest", + "ExportFeatureValuesResponse", "ExportModelOperationMetadata", "ExportModelRequest", "ExportModelResponse", + "ExportTensorboardTimeSeriesDataRequest", + "ExportTensorboardTimeSeriesDataResponse", "Feature", "FeatureNoiseSigma", "FeatureSelector", @@ -559,8 +634,13 @@ "GetModelEvaluationRequest", "GetModelEvaluationSliceRequest", "GetModelRequest", + "GetPipelineJobRequest", "GetSpecialistPoolRequest", "GetStudyRequest", + "GetTensorboardExperimentRequest", + "GetTensorboardRequest", + "GetTensorboardRunRequest", + "GetTensorboardTimeSeriesRequest", "GetTrainingPipelineRequest", "GetTrialRequest", "HyperparameterTuningJob", @@ -629,10 +709,20 @@ "ListModelsResponse", "ListOptimalTrialsRequest", "ListOptimalTrialsResponse", + "ListPipelineJobsRequest", + "ListPipelineJobsResponse", "ListSpecialistPoolsRequest", "ListSpecialistPoolsResponse", "ListStudiesRequest", "ListStudiesResponse", + "ListTensorboardExperimentsRequest", + "ListTensorboardExperimentsResponse", + "ListTensorboardRunsRequest", + "ListTensorboardRunsResponse", + "ListTensorboardTimeSeriesRequest", + "ListTensorboardTimeSeriesResponse", + "ListTensorboardsRequest", + "ListTensorboardsResponse", "ListTrainingPipelinesRequest", "ListTrainingPipelinesResponse", "ListTrialsRequest", @@ -663,8 +753,12 @@ "ModelServiceClient", "NearestNeighborSearchOperationMetadata", "PauseModelDeploymentMonitoringJobRequest", + "PipelineJob", + "PipelineJobDetail", "PipelineServiceClient", "PipelineState", + "PipelineTaskDetail", + "PipelineTaskExecutorDetail", "Port", "PredefinedSplit", "PredictRequest", @@ -677,12 +771,16 @@ "QueryExecutionInputsAndOutputsRequest", "ReadFeatureValuesRequest", "ReadFeatureValuesResponse", - "ReadSetting", + "ReadTensorboardBlobDataRequest", + "ReadTensorboardBlobDataResponse", + "ReadTensorboardTimeSeriesDataRequest", + "ReadTensorboardTimeSeriesDataResponse", "ResourcesConsumed", "ResumeModelDeploymentMonitoringJobRequest", "SampleConfig", "SampledShapleyAttribution", "SamplingStrategy", + "Scalar", "Scheduling", "SearchFeaturesRequest", "SearchFeaturesResponse", @@ -702,7 +800,17 @@ "SuggestTrialsRequest", "SuggestTrialsResponse", "TFRecordDestination", + "Tensorboard", + "TensorboardBlob", + "TensorboardBlobSequence", + "TensorboardExperiment", + "TensorboardRun", + "TensorboardServiceClient", + "TensorboardTensor", + "TensorboardTimeSeries", "ThresholdConfig", + "TimeSeriesData", + "TimeSeriesDataPoint", "TimestampSplit", "TrainingConfig", "TrainingPipeline", @@ -730,12 +838,20 @@ "UpdateModelRequest", "UpdateSpecialistPoolOperationMetadata", "UpdateSpecialistPoolRequest", + "UpdateTensorboardExperimentRequest", + "UpdateTensorboardOperationMetadata", + "UpdateTensorboardRequest", + "UpdateTensorboardRunRequest", + "UpdateTensorboardTimeSeriesRequest", "UploadModelOperationMetadata", "UploadModelRequest", "UploadModelResponse", "UserActionReference", + "Value", "VizierServiceClient", "WorkerPoolSpec", + "WriteTensorboardRunDataRequest", + "WriteTensorboardRunDataResponse", "XraiAttribution", "MetadataServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json new file mode 100644 index 0000000000..605e95582d --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -0,0 +1,1949 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.aiplatform_v1beta1", + "protoPackage": "google.cloud.aiplatform.v1beta1", + "schema": "1.0", + "services": { + "DatasetService": { + "clients": { + "grpc": { + "libraryClient": "DatasetServiceClient", + "rpcs": { + "CreateDataset": { + "methods": [ + "create_dataset" + ] + }, + "DeleteDataset": { + "methods": [ + "delete_dataset" + ] + }, + "ExportData": { + "methods": [ + "export_data" + ] + }, + "GetAnnotationSpec": { + "methods": [ + "get_annotation_spec" + ] + }, + "GetDataset": { + "methods": [ + "get_dataset" + ] + }, + "ImportData": { + "methods": [ + "import_data" + ] + }, + "ListAnnotations": { + "methods": [ + "list_annotations" + ] + }, + "ListDataItems": { + "methods": [ + "list_data_items" + ] + }, + "ListDatasets": { + "methods": [ + "list_datasets" + ] + }, + "UpdateDataset": { + "methods": [ + "update_dataset" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DatasetServiceAsyncClient", + "rpcs": { + "CreateDataset": { + "methods": [ + "create_dataset" + ] + }, + "DeleteDataset": { + "methods": [ + "delete_dataset" + ] + }, + "ExportData": { + "methods": [ + "export_data" + ] + }, + "GetAnnotationSpec": { + "methods": [ + "get_annotation_spec" + ] + }, + "GetDataset": { + "methods": [ + "get_dataset" + ] + }, + "ImportData": { + "methods": [ + "import_data" + ] + }, + "ListAnnotations": { + "methods": [ + "list_annotations" + ] + }, + "ListDataItems": { + "methods": [ + "list_data_items" + ] + }, + "ListDatasets": { + "methods": [ + "list_datasets" + ] + }, + "UpdateDataset": { + "methods": [ + "update_dataset" + ] + } + } + } + } + }, + "EndpointService": { + "clients": { + "grpc": { + "libraryClient": "EndpointServiceClient", + "rpcs": { + "CreateEndpoint": { + "methods": [ + "create_endpoint" + ] + }, + "DeleteEndpoint": { + "methods": [ + "delete_endpoint" + ] + }, + "DeployModel": { + "methods": [ + "deploy_model" + ] + }, + "GetEndpoint": { + "methods": [ + "get_endpoint" + ] + }, + "ListEndpoints": { + "methods": [ + "list_endpoints" + ] + }, + "UndeployModel": { + "methods": [ + "undeploy_model" + ] + }, + "UpdateEndpoint": { + "methods": [ + "update_endpoint" + ] + } + } + }, + "grpc-async": { + "libraryClient": "EndpointServiceAsyncClient", + "rpcs": { + "CreateEndpoint": { + "methods": [ + "create_endpoint" + ] + }, + "DeleteEndpoint": { + "methods": [ + "delete_endpoint" + ] + }, + "DeployModel": { + "methods": [ + "deploy_model" + ] + }, + "GetEndpoint": { + "methods": [ + "get_endpoint" + ] + }, + "ListEndpoints": { + "methods": [ + "list_endpoints" + ] + }, + "UndeployModel": { + "methods": [ + "undeploy_model" + ] + }, + "UpdateEndpoint": { + "methods": [ + "update_endpoint" + ] + } + } + } + } + }, + "FeaturestoreOnlineServingService": { + "clients": { + "grpc": { + "libraryClient": "FeaturestoreOnlineServingServiceClient", + "rpcs": { + "ReadFeatureValues": { + "methods": [ + "read_feature_values" + ] + }, + "StreamingReadFeatureValues": { + "methods": [ + "streaming_read_feature_values" + ] + } + } + }, + "grpc-async": { + "libraryClient": "FeaturestoreOnlineServingServiceAsyncClient", + "rpcs": { + "ReadFeatureValues": { + "methods": [ + "read_feature_values" + ] + }, + "StreamingReadFeatureValues": { + "methods": [ + "streaming_read_feature_values" + ] + } + } + } + } + }, + "FeaturestoreService": { + "clients": { + "grpc": { + "libraryClient": "FeaturestoreServiceClient", + "rpcs": { + "BatchCreateFeatures": { + "methods": [ + "batch_create_features" + ] + }, + "BatchReadFeatureValues": { + "methods": [ + "batch_read_feature_values" + ] + }, + "CreateEntityType": { + "methods": [ + "create_entity_type" + ] + }, + "CreateFeature": { + "methods": [ + "create_feature" + ] + }, + "CreateFeaturestore": { + "methods": [ + "create_featurestore" + ] + }, + "DeleteEntityType": { + "methods": [ + "delete_entity_type" + ] + }, + "DeleteFeature": { + "methods": [ + "delete_feature" + ] + }, + "DeleteFeaturestore": { + "methods": [ + "delete_featurestore" + ] + }, + "ExportFeatureValues": { + "methods": [ + "export_feature_values" + ] + }, + "GetEntityType": { + "methods": [ + "get_entity_type" + ] + }, + "GetFeature": { + "methods": [ + "get_feature" + ] + }, + "GetFeaturestore": { + "methods": [ + "get_featurestore" + ] + }, + "ImportFeatureValues": { + "methods": [ + "import_feature_values" + ] + }, + "ListEntityTypes": { + "methods": [ + "list_entity_types" + ] + }, + "ListFeatures": { + "methods": [ + "list_features" + ] + }, + "ListFeaturestores": { + "methods": [ + "list_featurestores" + ] + }, + "SearchFeatures": { + "methods": [ + "search_features" + ] + }, + "UpdateEntityType": { + "methods": [ + "update_entity_type" + ] + }, + "UpdateFeature": { + "methods": [ + "update_feature" + ] + }, + "UpdateFeaturestore": { + "methods": [ + "update_featurestore" + ] + } + } + }, + "grpc-async": { + "libraryClient": "FeaturestoreServiceAsyncClient", + "rpcs": { + "BatchCreateFeatures": { + "methods": [ + "batch_create_features" + ] + }, + "BatchReadFeatureValues": { + "methods": [ + "batch_read_feature_values" + ] + }, + "CreateEntityType": { + "methods": [ + "create_entity_type" + ] + }, + "CreateFeature": { + "methods": [ + "create_feature" + ] + }, + "CreateFeaturestore": { + "methods": [ + "create_featurestore" + ] + }, + "DeleteEntityType": { + "methods": [ + "delete_entity_type" + ] + }, + "DeleteFeature": { + "methods": [ + "delete_feature" + ] + }, + "DeleteFeaturestore": { + "methods": [ + "delete_featurestore" + ] + }, + "ExportFeatureValues": { + "methods": [ + "export_feature_values" + ] + }, + "GetEntityType": { + "methods": [ + "get_entity_type" + ] + }, + "GetFeature": { + "methods": [ + "get_feature" + ] + }, + "GetFeaturestore": { + "methods": [ + "get_featurestore" + ] + }, + "ImportFeatureValues": { + "methods": [ + "import_feature_values" + ] + }, + "ListEntityTypes": { + "methods": [ + "list_entity_types" + ] + }, + "ListFeatures": { + "methods": [ + "list_features" + ] + }, + "ListFeaturestores": { + "methods": [ + "list_featurestores" + ] + }, + "SearchFeatures": { + "methods": [ + "search_features" + ] + }, + "UpdateEntityType": { + "methods": [ + "update_entity_type" + ] + }, + "UpdateFeature": { + "methods": [ + "update_feature" + ] + }, + "UpdateFeaturestore": { + "methods": [ + "update_featurestore" + ] + } + } + } + } + }, + "IndexEndpointService": { + "clients": { + "grpc": { + "libraryClient": "IndexEndpointServiceClient", + "rpcs": { + "CreateIndexEndpoint": { + "methods": [ + "create_index_endpoint" + ] + }, + "DeleteIndexEndpoint": { + "methods": [ + "delete_index_endpoint" + ] + }, + "DeployIndex": { + "methods": [ + "deploy_index" + ] + }, + "GetIndexEndpoint": { + "methods": [ + "get_index_endpoint" + ] + }, + "ListIndexEndpoints": { + "methods": [ + "list_index_endpoints" + ] + }, + "UndeployIndex": { + "methods": [ + "undeploy_index" + ] + }, + "UpdateIndexEndpoint": { + "methods": [ + "update_index_endpoint" + ] + } + } + }, + "grpc-async": { + "libraryClient": "IndexEndpointServiceAsyncClient", + "rpcs": { + "CreateIndexEndpoint": { + "methods": [ + "create_index_endpoint" + ] + }, + "DeleteIndexEndpoint": { + "methods": [ + "delete_index_endpoint" + ] + }, + "DeployIndex": { + "methods": [ + "deploy_index" + ] + }, + "GetIndexEndpoint": { + "methods": [ + "get_index_endpoint" + ] + }, + "ListIndexEndpoints": { + "methods": [ + "list_index_endpoints" + ] + }, + "UndeployIndex": { + "methods": [ + "undeploy_index" + ] + }, + "UpdateIndexEndpoint": { + "methods": [ + "update_index_endpoint" + ] + } + } + } + } + }, + "IndexService": { + "clients": { + "grpc": { + "libraryClient": "IndexServiceClient", + "rpcs": { + "CreateIndex": { + "methods": [ + "create_index" + ] + }, + "DeleteIndex": { + "methods": [ + "delete_index" + ] + }, + "GetIndex": { + "methods": [ + "get_index" + ] + }, + "ListIndexes": { + "methods": [ + "list_indexes" + ] + }, + "UpdateIndex": { + "methods": [ + "update_index" + ] + } + } + }, + "grpc-async": { + "libraryClient": "IndexServiceAsyncClient", + "rpcs": { + "CreateIndex": { + "methods": [ + "create_index" + ] + }, + "DeleteIndex": { + "methods": [ + "delete_index" + ] + }, + "GetIndex": { + "methods": [ + "get_index" + ] + }, + "ListIndexes": { + "methods": [ + "list_indexes" + ] + }, + "UpdateIndex": { + "methods": [ + "update_index" + ] + } + } + } + } + }, + "JobService": { + "clients": { + "grpc": { + "libraryClient": "JobServiceClient", + "rpcs": { + "CancelBatchPredictionJob": { + "methods": [ + "cancel_batch_prediction_job" + ] + }, + "CancelCustomJob": { + "methods": [ + "cancel_custom_job" + ] + }, + "CancelDataLabelingJob": { + "methods": [ + "cancel_data_labeling_job" + ] + }, + "CancelHyperparameterTuningJob": { + "methods": [ + "cancel_hyperparameter_tuning_job" + ] + }, + "CreateBatchPredictionJob": { + "methods": [ + "create_batch_prediction_job" + ] + }, + "CreateCustomJob": { + "methods": [ + "create_custom_job" + ] + }, + "CreateDataLabelingJob": { + "methods": [ + "create_data_labeling_job" + ] + }, + "CreateHyperparameterTuningJob": { + "methods": [ + "create_hyperparameter_tuning_job" + ] + }, + "CreateModelDeploymentMonitoringJob": { + "methods": [ + "create_model_deployment_monitoring_job" + ] + }, + "DeleteBatchPredictionJob": { + "methods": [ + "delete_batch_prediction_job" + ] + }, + "DeleteCustomJob": { + "methods": [ + "delete_custom_job" + ] + }, + "DeleteDataLabelingJob": { + "methods": [ + "delete_data_labeling_job" + ] + }, + "DeleteHyperparameterTuningJob": { + "methods": [ + "delete_hyperparameter_tuning_job" + ] + }, + "DeleteModelDeploymentMonitoringJob": { + "methods": [ + "delete_model_deployment_monitoring_job" + ] + }, + "GetBatchPredictionJob": { + "methods": [ + "get_batch_prediction_job" + ] + }, + "GetCustomJob": { + "methods": [ + "get_custom_job" + ] + }, + "GetDataLabelingJob": { + "methods": [ + "get_data_labeling_job" + ] + }, + "GetHyperparameterTuningJob": { + "methods": [ + "get_hyperparameter_tuning_job" + ] + }, + "GetModelDeploymentMonitoringJob": { + "methods": [ + "get_model_deployment_monitoring_job" + ] + }, + "ListBatchPredictionJobs": { + "methods": [ + "list_batch_prediction_jobs" + ] + }, + "ListCustomJobs": { + "methods": [ + "list_custom_jobs" + ] + }, + "ListDataLabelingJobs": { + "methods": [ + "list_data_labeling_jobs" + ] + }, + "ListHyperparameterTuningJobs": { + "methods": [ + "list_hyperparameter_tuning_jobs" + ] + }, + "ListModelDeploymentMonitoringJobs": { + "methods": [ + "list_model_deployment_monitoring_jobs" + ] + }, + "PauseModelDeploymentMonitoringJob": { + "methods": [ + "pause_model_deployment_monitoring_job" + ] + }, + "ResumeModelDeploymentMonitoringJob": { + "methods": [ + "resume_model_deployment_monitoring_job" + ] + }, + "SearchModelDeploymentMonitoringStatsAnomalies": { + "methods": [ + "search_model_deployment_monitoring_stats_anomalies" + ] + }, + "UpdateModelDeploymentMonitoringJob": { + "methods": [ + "update_model_deployment_monitoring_job" + ] + } + } + }, + "grpc-async": { + "libraryClient": "JobServiceAsyncClient", + "rpcs": { + "CancelBatchPredictionJob": { + "methods": [ + "cancel_batch_prediction_job" + ] + }, + "CancelCustomJob": { + "methods": [ + "cancel_custom_job" + ] + }, + "CancelDataLabelingJob": { + "methods": [ + "cancel_data_labeling_job" + ] + }, + "CancelHyperparameterTuningJob": { + "methods": [ + "cancel_hyperparameter_tuning_job" + ] + }, + "CreateBatchPredictionJob": { + "methods": [ + "create_batch_prediction_job" + ] + }, + "CreateCustomJob": { + "methods": [ + "create_custom_job" + ] + }, + "CreateDataLabelingJob": { + "methods": [ + "create_data_labeling_job" + ] + }, + "CreateHyperparameterTuningJob": { + "methods": [ + "create_hyperparameter_tuning_job" + ] + }, + "CreateModelDeploymentMonitoringJob": { + "methods": [ + "create_model_deployment_monitoring_job" + ] + }, + "DeleteBatchPredictionJob": { + "methods": [ + "delete_batch_prediction_job" + ] + }, + "DeleteCustomJob": { + "methods": [ + "delete_custom_job" + ] + }, + "DeleteDataLabelingJob": { + "methods": [ + "delete_data_labeling_job" + ] + }, + "DeleteHyperparameterTuningJob": { + "methods": [ + "delete_hyperparameter_tuning_job" + ] + }, + "DeleteModelDeploymentMonitoringJob": { + "methods": [ + "delete_model_deployment_monitoring_job" + ] + }, + "GetBatchPredictionJob": { + "methods": [ + "get_batch_prediction_job" + ] + }, + "GetCustomJob": { + "methods": [ + "get_custom_job" + ] + }, + "GetDataLabelingJob": { + "methods": [ + "get_data_labeling_job" + ] + }, + "GetHyperparameterTuningJob": { + "methods": [ + "get_hyperparameter_tuning_job" + ] + }, + "GetModelDeploymentMonitoringJob": { + "methods": [ + "get_model_deployment_monitoring_job" + ] + }, + "ListBatchPredictionJobs": { + "methods": [ + "list_batch_prediction_jobs" + ] + }, + "ListCustomJobs": { + "methods": [ + "list_custom_jobs" + ] + }, + "ListDataLabelingJobs": { + "methods": [ + "list_data_labeling_jobs" + ] + }, + "ListHyperparameterTuningJobs": { + "methods": [ + "list_hyperparameter_tuning_jobs" + ] + }, + "ListModelDeploymentMonitoringJobs": { + "methods": [ + "list_model_deployment_monitoring_jobs" + ] + }, + "PauseModelDeploymentMonitoringJob": { + "methods": [ + "pause_model_deployment_monitoring_job" + ] + }, + "ResumeModelDeploymentMonitoringJob": { + "methods": [ + "resume_model_deployment_monitoring_job" + ] + }, + "SearchModelDeploymentMonitoringStatsAnomalies": { + "methods": [ + "search_model_deployment_monitoring_stats_anomalies" + ] + }, + "UpdateModelDeploymentMonitoringJob": { + "methods": [ + "update_model_deployment_monitoring_job" + ] + } + } + } + } + }, + "MetadataService": { + "clients": { + "grpc": { + "libraryClient": "MetadataServiceClient", + "rpcs": { + "AddContextArtifactsAndExecutions": { + "methods": [ + "add_context_artifacts_and_executions" + ] + }, + "AddContextChildren": { + "methods": [ + "add_context_children" + ] + }, + "AddExecutionEvents": { + "methods": [ + "add_execution_events" + ] + }, + "CreateArtifact": { + "methods": [ + "create_artifact" + ] + }, + "CreateContext": { + "methods": [ + "create_context" + ] + }, + "CreateExecution": { + "methods": [ + "create_execution" + ] + }, + "CreateMetadataSchema": { + "methods": [ + "create_metadata_schema" + ] + }, + "CreateMetadataStore": { + "methods": [ + "create_metadata_store" + ] + }, + "DeleteContext": { + "methods": [ + "delete_context" + ] + }, + "DeleteMetadataStore": { + "methods": [ + "delete_metadata_store" + ] + }, + "GetArtifact": { + "methods": [ + "get_artifact" + ] + }, + "GetContext": { + "methods": [ + "get_context" + ] + }, + "GetExecution": { + "methods": [ + "get_execution" + ] + }, + "GetMetadataSchema": { + "methods": [ + "get_metadata_schema" + ] + }, + "GetMetadataStore": { + "methods": [ + "get_metadata_store" + ] + }, + "ListArtifacts": { + "methods": [ + "list_artifacts" + ] + }, + "ListContexts": { + "methods": [ + "list_contexts" + ] + }, + "ListExecutions": { + "methods": [ + "list_executions" + ] + }, + "ListMetadataSchemas": { + "methods": [ + "list_metadata_schemas" + ] + }, + "ListMetadataStores": { + "methods": [ + "list_metadata_stores" + ] + }, + "QueryArtifactLineageSubgraph": { + "methods": [ + "query_artifact_lineage_subgraph" + ] + }, + "QueryContextLineageSubgraph": { + "methods": [ + "query_context_lineage_subgraph" + ] + }, + "QueryExecutionInputsAndOutputs": { + "methods": [ + "query_execution_inputs_and_outputs" + ] + }, + "UpdateArtifact": { + "methods": [ + "update_artifact" + ] + }, + "UpdateContext": { + "methods": [ + "update_context" + ] + }, + "UpdateExecution": { + "methods": [ + "update_execution" + ] + } + } + }, + "grpc-async": { + "libraryClient": "MetadataServiceAsyncClient", + "rpcs": { + "AddContextArtifactsAndExecutions": { + "methods": [ + "add_context_artifacts_and_executions" + ] + }, + "AddContextChildren": { + "methods": [ + "add_context_children" + ] + }, + "AddExecutionEvents": { + "methods": [ + "add_execution_events" + ] + }, + "CreateArtifact": { + "methods": [ + "create_artifact" + ] + }, + "CreateContext": { + "methods": [ + "create_context" + ] + }, + "CreateExecution": { + "methods": [ + "create_execution" + ] + }, + "CreateMetadataSchema": { + "methods": [ + "create_metadata_schema" + ] + }, + "CreateMetadataStore": { + "methods": [ + "create_metadata_store" + ] + }, + "DeleteContext": { + "methods": [ + "delete_context" + ] + }, + "DeleteMetadataStore": { + "methods": [ + "delete_metadata_store" + ] + }, + "GetArtifact": { + "methods": [ + "get_artifact" + ] + }, + "GetContext": { + "methods": [ + "get_context" + ] + }, + "GetExecution": { + "methods": [ + "get_execution" + ] + }, + "GetMetadataSchema": { + "methods": [ + "get_metadata_schema" + ] + }, + "GetMetadataStore": { + "methods": [ + "get_metadata_store" + ] + }, + "ListArtifacts": { + "methods": [ + "list_artifacts" + ] + }, + "ListContexts": { + "methods": [ + "list_contexts" + ] + }, + "ListExecutions": { + "methods": [ + "list_executions" + ] + }, + "ListMetadataSchemas": { + "methods": [ + "list_metadata_schemas" + ] + }, + "ListMetadataStores": { + "methods": [ + "list_metadata_stores" + ] + }, + "QueryArtifactLineageSubgraph": { + "methods": [ + "query_artifact_lineage_subgraph" + ] + }, + "QueryContextLineageSubgraph": { + "methods": [ + "query_context_lineage_subgraph" + ] + }, + "QueryExecutionInputsAndOutputs": { + "methods": [ + "query_execution_inputs_and_outputs" + ] + }, + "UpdateArtifact": { + "methods": [ + "update_artifact" + ] + }, + "UpdateContext": { + "methods": [ + "update_context" + ] + }, + "UpdateExecution": { + "methods": [ + "update_execution" + ] + } + } + } + } + }, + "MigrationService": { + "clients": { + "grpc": { + "libraryClient": "MigrationServiceClient", + "rpcs": { + "BatchMigrateResources": { + "methods": [ + "batch_migrate_resources" + ] + }, + "SearchMigratableResources": { + "methods": [ + "search_migratable_resources" + ] + } + } + }, + "grpc-async": { + "libraryClient": "MigrationServiceAsyncClient", + "rpcs": { + "BatchMigrateResources": { + "methods": [ + "batch_migrate_resources" + ] + }, + "SearchMigratableResources": { + "methods": [ + "search_migratable_resources" + ] + } + } + } + } + }, + "ModelService": { + "clients": { + "grpc": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "DeleteModel": { + "methods": [ + "delete_model" + ] + }, + "ExportModel": { + "methods": [ + "export_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetModelEvaluation": { + "methods": [ + "get_model_evaluation" + ] + }, + "GetModelEvaluationSlice": { + "methods": [ + "get_model_evaluation_slice" + ] + }, + "ListModelEvaluationSlices": { + "methods": [ + "list_model_evaluation_slices" + ] + }, + "ListModelEvaluations": { + "methods": [ + "list_model_evaluations" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "UpdateModel": { + "methods": [ + "update_model" + ] + }, + "UploadModel": { + "methods": [ + "upload_model" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ModelServiceAsyncClient", + "rpcs": { + "DeleteModel": { + "methods": [ + "delete_model" + ] + }, + "ExportModel": { + "methods": [ + "export_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetModelEvaluation": { + "methods": [ + "get_model_evaluation" + ] + }, + "GetModelEvaluationSlice": { + "methods": [ + "get_model_evaluation_slice" + ] + }, + "ListModelEvaluationSlices": { + "methods": [ + "list_model_evaluation_slices" + ] + }, + "ListModelEvaluations": { + "methods": [ + "list_model_evaluations" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "UpdateModel": { + "methods": [ + "update_model" + ] + }, + "UploadModel": { + "methods": [ + "upload_model" + ] + } + } + } + } + }, + "PipelineService": { + "clients": { + "grpc": { + "libraryClient": "PipelineServiceClient", + "rpcs": { + "CancelPipelineJob": { + "methods": [ + "cancel_pipeline_job" + ] + }, + "CancelTrainingPipeline": { + "methods": [ + "cancel_training_pipeline" + ] + }, + "CreatePipelineJob": { + "methods": [ + "create_pipeline_job" + ] + }, + "CreateTrainingPipeline": { + "methods": [ + "create_training_pipeline" + ] + }, + "DeletePipelineJob": { + "methods": [ + "delete_pipeline_job" + ] + }, + "DeleteTrainingPipeline": { + "methods": [ + "delete_training_pipeline" + ] + }, + "GetPipelineJob": { + "methods": [ + "get_pipeline_job" + ] + }, + "GetTrainingPipeline": { + "methods": [ + "get_training_pipeline" + ] + }, + "ListPipelineJobs": { + "methods": [ + "list_pipeline_jobs" + ] + }, + "ListTrainingPipelines": { + "methods": [ + "list_training_pipelines" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PipelineServiceAsyncClient", + "rpcs": { + "CancelPipelineJob": { + "methods": [ + "cancel_pipeline_job" + ] + }, + "CancelTrainingPipeline": { + "methods": [ + "cancel_training_pipeline" + ] + }, + "CreatePipelineJob": { + "methods": [ + "create_pipeline_job" + ] + }, + "CreateTrainingPipeline": { + "methods": [ + "create_training_pipeline" + ] + }, + "DeletePipelineJob": { + "methods": [ + "delete_pipeline_job" + ] + }, + "DeleteTrainingPipeline": { + "methods": [ + "delete_training_pipeline" + ] + }, + "GetPipelineJob": { + "methods": [ + "get_pipeline_job" + ] + }, + "GetTrainingPipeline": { + "methods": [ + "get_training_pipeline" + ] + }, + "ListPipelineJobs": { + "methods": [ + "list_pipeline_jobs" + ] + }, + "ListTrainingPipelines": { + "methods": [ + "list_training_pipelines" + ] + } + } + } + } + }, + "PredictionService": { + "clients": { + "grpc": { + "libraryClient": "PredictionServiceClient", + "rpcs": { + "Explain": { + "methods": [ + "explain" + ] + }, + "Predict": { + "methods": [ + "predict" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PredictionServiceAsyncClient", + "rpcs": { + "Explain": { + "methods": [ + "explain" + ] + }, + "Predict": { + "methods": [ + "predict" + ] + } + } + } + } + }, + "SpecialistPoolService": { + "clients": { + "grpc": { + "libraryClient": "SpecialistPoolServiceClient", + "rpcs": { + "CreateSpecialistPool": { + "methods": [ + "create_specialist_pool" + ] + }, + "DeleteSpecialistPool": { + "methods": [ + "delete_specialist_pool" + ] + }, + "GetSpecialistPool": { + "methods": [ + "get_specialist_pool" + ] + }, + "ListSpecialistPools": { + "methods": [ + "list_specialist_pools" + ] + }, + "UpdateSpecialistPool": { + "methods": [ + "update_specialist_pool" + ] + } + } + }, + "grpc-async": { + "libraryClient": "SpecialistPoolServiceAsyncClient", + "rpcs": { + "CreateSpecialistPool": { + "methods": [ + "create_specialist_pool" + ] + }, + "DeleteSpecialistPool": { + "methods": [ + "delete_specialist_pool" + ] + }, + "GetSpecialistPool": { + "methods": [ + "get_specialist_pool" + ] + }, + "ListSpecialistPools": { + "methods": [ + "list_specialist_pools" + ] + }, + "UpdateSpecialistPool": { + "methods": [ + "update_specialist_pool" + ] + } + } + } + } + }, + "TensorboardService": { + "clients": { + "grpc": { + "libraryClient": "TensorboardServiceClient", + "rpcs": { + "CreateTensorboard": { + "methods": [ + "create_tensorboard" + ] + }, + "CreateTensorboardExperiment": { + "methods": [ + "create_tensorboard_experiment" + ] + }, + "CreateTensorboardRun": { + "methods": [ + "create_tensorboard_run" + ] + }, + "CreateTensorboardTimeSeries": { + "methods": [ + "create_tensorboard_time_series" + ] + }, + "DeleteTensorboard": { + "methods": [ + "delete_tensorboard" + ] + }, + "DeleteTensorboardExperiment": { + "methods": [ + "delete_tensorboard_experiment" + ] + }, + "DeleteTensorboardRun": { + "methods": [ + "delete_tensorboard_run" + ] + }, + "DeleteTensorboardTimeSeries": { + "methods": [ + "delete_tensorboard_time_series" + ] + }, + "ExportTensorboardTimeSeriesData": { + "methods": [ + "export_tensorboard_time_series_data" + ] + }, + "GetTensorboard": { + "methods": [ + "get_tensorboard" + ] + }, + "GetTensorboardExperiment": { + "methods": [ + "get_tensorboard_experiment" + ] + }, + "GetTensorboardRun": { + "methods": [ + "get_tensorboard_run" + ] + }, + "GetTensorboardTimeSeries": { + "methods": [ + "get_tensorboard_time_series" + ] + }, + "ListTensorboardExperiments": { + "methods": [ + "list_tensorboard_experiments" + ] + }, + "ListTensorboardRuns": { + "methods": [ + "list_tensorboard_runs" + ] + }, + "ListTensorboardTimeSeries": { + "methods": [ + "list_tensorboard_time_series" + ] + }, + "ListTensorboards": { + "methods": [ + "list_tensorboards" + ] + }, + "ReadTensorboardBlobData": { + "methods": [ + "read_tensorboard_blob_data" + ] + }, + "ReadTensorboardTimeSeriesData": { + "methods": [ + "read_tensorboard_time_series_data" + ] + }, + "UpdateTensorboard": { + "methods": [ + "update_tensorboard" + ] + }, + "UpdateTensorboardExperiment": { + "methods": [ + "update_tensorboard_experiment" + ] + }, + "UpdateTensorboardRun": { + "methods": [ + "update_tensorboard_run" + ] + }, + "UpdateTensorboardTimeSeries": { + "methods": [ + "update_tensorboard_time_series" + ] + }, + "WriteTensorboardRunData": { + "methods": [ + "write_tensorboard_run_data" + ] + } + } + }, + "grpc-async": { + "libraryClient": "TensorboardServiceAsyncClient", + "rpcs": { + "CreateTensorboard": { + "methods": [ + "create_tensorboard" + ] + }, + "CreateTensorboardExperiment": { + "methods": [ + "create_tensorboard_experiment" + ] + }, + "CreateTensorboardRun": { + "methods": [ + "create_tensorboard_run" + ] + }, + "CreateTensorboardTimeSeries": { + "methods": [ + "create_tensorboard_time_series" + ] + }, + "DeleteTensorboard": { + "methods": [ + "delete_tensorboard" + ] + }, + "DeleteTensorboardExperiment": { + "methods": [ + "delete_tensorboard_experiment" + ] + }, + "DeleteTensorboardRun": { + "methods": [ + "delete_tensorboard_run" + ] + }, + "DeleteTensorboardTimeSeries": { + "methods": [ + "delete_tensorboard_time_series" + ] + }, + "ExportTensorboardTimeSeriesData": { + "methods": [ + "export_tensorboard_time_series_data" + ] + }, + "GetTensorboard": { + "methods": [ + "get_tensorboard" + ] + }, + "GetTensorboardExperiment": { + "methods": [ + "get_tensorboard_experiment" + ] + }, + "GetTensorboardRun": { + "methods": [ + "get_tensorboard_run" + ] + }, + "GetTensorboardTimeSeries": { + "methods": [ + "get_tensorboard_time_series" + ] + }, + "ListTensorboardExperiments": { + "methods": [ + "list_tensorboard_experiments" + ] + }, + "ListTensorboardRuns": { + "methods": [ + "list_tensorboard_runs" + ] + }, + "ListTensorboardTimeSeries": { + "methods": [ + "list_tensorboard_time_series" + ] + }, + "ListTensorboards": { + "methods": [ + "list_tensorboards" + ] + }, + "ReadTensorboardBlobData": { + "methods": [ + "read_tensorboard_blob_data" + ] + }, + "ReadTensorboardTimeSeriesData": { + "methods": [ + "read_tensorboard_time_series_data" + ] + }, + "UpdateTensorboard": { + "methods": [ + "update_tensorboard" + ] + }, + "UpdateTensorboardExperiment": { + "methods": [ + "update_tensorboard_experiment" + ] + }, + "UpdateTensorboardRun": { + "methods": [ + "update_tensorboard_run" + ] + }, + "UpdateTensorboardTimeSeries": { + "methods": [ + "update_tensorboard_time_series" + ] + }, + "WriteTensorboardRunData": { + "methods": [ + "write_tensorboard_run_data" + ] + } + } + } + } + }, + "VizierService": { + "clients": { + "grpc": { + "libraryClient": "VizierServiceClient", + "rpcs": { + "AddTrialMeasurement": { + "methods": [ + "add_trial_measurement" + ] + }, + "CheckTrialEarlyStoppingState": { + "methods": [ + "check_trial_early_stopping_state" + ] + }, + "CompleteTrial": { + "methods": [ + "complete_trial" + ] + }, + "CreateStudy": { + "methods": [ + "create_study" + ] + }, + "CreateTrial": { + "methods": [ + "create_trial" + ] + }, + "DeleteStudy": { + "methods": [ + "delete_study" + ] + }, + "DeleteTrial": { + "methods": [ + "delete_trial" + ] + }, + "GetStudy": { + "methods": [ + "get_study" + ] + }, + "GetTrial": { + "methods": [ + "get_trial" + ] + }, + "ListOptimalTrials": { + "methods": [ + "list_optimal_trials" + ] + }, + "ListStudies": { + "methods": [ + "list_studies" + ] + }, + "ListTrials": { + "methods": [ + "list_trials" + ] + }, + "LookupStudy": { + "methods": [ + "lookup_study" + ] + }, + "StopTrial": { + "methods": [ + "stop_trial" + ] + }, + "SuggestTrials": { + "methods": [ + "suggest_trials" + ] + } + } + }, + "grpc-async": { + "libraryClient": "VizierServiceAsyncClient", + "rpcs": { + "AddTrialMeasurement": { + "methods": [ + "add_trial_measurement" + ] + }, + "CheckTrialEarlyStoppingState": { + "methods": [ + "check_trial_early_stopping_state" + ] + }, + "CompleteTrial": { + "methods": [ + "complete_trial" + ] + }, + "CreateStudy": { + "methods": [ + "create_study" + ] + }, + "CreateTrial": { + "methods": [ + "create_trial" + ] + }, + "DeleteStudy": { + "methods": [ + "delete_study" + ] + }, + "DeleteTrial": { + "methods": [ + "delete_trial" + ] + }, + "GetStudy": { + "methods": [ + "get_study" + ] + }, + "GetTrial": { + "methods": [ + "get_trial" + ] + }, + "ListOptimalTrials": { + "methods": [ + "list_optimal_trials" + ] + }, + "ListStudies": { + "methods": [ + "list_studies" + ] + }, + "ListTrials": { + "methods": [ + "list_trials" + ] + }, + "LookupStudy": { + "methods": [ + "lookup_study" + ] + }, + "StopTrial": { + "methods": [ + "stop_trial" + ] + }, + "SuggestTrials": { + "methods": [ + "suggest_trials" + ] + } + } + } + } + } + } +} diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index f58bc25d1b..9bb0bbda1c 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -1860,6 +1860,95 @@ async def batch_read_feature_values( # Done; return the response. return response + async def export_feature_values( + self, + request: featurestore_service.ExportFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Exports Feature values from all the entities of a + target EntityType. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ExportFeatureValuesRequest`): + The request object. Request message for + ``FeaturestoreService.ExportFeatureValues``. + entity_type (:class:`str`): + Required. The resource name of the EntityType from which + to export Feature values. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.ExportFeatureValuesResponse` + Response message for + ``FeaturestoreService.ExportFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = featurestore_service.ExportFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_feature_values, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + featurestore_service.ExportFeatureValuesResponse, + metadata_type=featurestore_service.ExportFeatureValuesOperationMetadata, + ) + + # Done; return the response. + return response + async def search_features( self, request: featurestore_service.SearchFeaturesRequest = None, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index 2b9991c9ba..b303ef4c25 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -2089,6 +2089,96 @@ def batch_read_feature_values( # Done; return the response. return response + def export_feature_values( + self, + request: featurestore_service.ExportFeatureValuesRequest = None, + *, + entity_type: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Exports Feature values from all the entities of a + target EntityType. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ExportFeatureValuesRequest): + The request object. Request message for + ``FeaturestoreService.ExportFeatureValues``. + entity_type (str): + Required. The resource name of the EntityType from which + to export Feature values. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + + This corresponds to the ``entity_type`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.ExportFeatureValuesResponse` + Response message for + ``FeaturestoreService.ExportFeatureValues``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([entity_type]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a featurestore_service.ExportFeatureValuesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, featurestore_service.ExportFeatureValuesRequest): + request = featurestore_service.ExportFeatureValuesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if entity_type is not None: + request.entity_type = entity_type + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.export_feature_values] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("entity_type", request.entity_type),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + featurestore_service.ExportFeatureValuesResponse, + metadata_type=featurestore_service.ExportFeatureValuesOperationMetadata, + ) + + # Done; return the response. + return response + def search_features( self, request: featurestore_service.SearchFeaturesRequest = None, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py index 2f633c4f81..9a6277cf39 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py @@ -173,6 +173,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.export_feature_values: gapic_v1.method.wrap_method( + self.export_feature_values, + default_timeout=None, + client_info=client_info, + ), self.search_features: gapic_v1.method.wrap_method( self.search_features, default_timeout=None, client_info=client_info, ), @@ -358,6 +363,15 @@ def batch_read_feature_values( ]: raise NotImplementedError() + @property + def export_feature_values( + self, + ) -> typing.Callable[ + [featurestore_service.ExportFeatureValuesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + @property def search_features( self, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py index ab15959efd..27c255d8a6 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py @@ -767,6 +767,35 @@ def batch_read_feature_values( ) return self._stubs["batch_read_feature_values"] + @property + def export_feature_values( + self, + ) -> Callable[ + [featurestore_service.ExportFeatureValuesRequest], operations.Operation + ]: + r"""Return a callable for the export feature values method over gRPC. + + Exports Feature values from all the entities of a + target EntityType. + + Returns: + Callable[[~.ExportFeatureValuesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_feature_values" not in self._stubs: + self._stubs["export_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ExportFeatureValues", + request_serializer=featurestore_service.ExportFeatureValuesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_feature_values"] + @property def search_features( self, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py index e0a4e35394..148ac3c1a9 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py @@ -793,6 +793,36 @@ def batch_read_feature_values( ) return self._stubs["batch_read_feature_values"] + @property + def export_feature_values( + self, + ) -> Callable[ + [featurestore_service.ExportFeatureValuesRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the export feature values method over gRPC. + + Exports Feature values from all the entities of a + target EntityType. + + Returns: + Callable[[~.ExportFeatureValuesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_feature_values" not in self._stubs: + self._stubs["export_feature_values"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.FeaturestoreService/ExportFeatureValues", + request_serializer=featurestore_service.ExportFeatureValuesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_feature_values"] + @property def search_features( self, diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py index 06dd2a9d72..bedc754f6f 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py @@ -58,10 +58,6 @@ class IndexEndpointServiceAsyncClient: parse_index_endpoint_path = staticmethod( IndexEndpointServiceClient.parse_index_endpoint_path ) - index_endpoint_path = staticmethod(IndexEndpointServiceClient.index_endpoint_path) - parse_index_endpoint_path = staticmethod( - IndexEndpointServiceClient.parse_index_endpoint_path - ) common_billing_account_path = staticmethod( IndexEndpointServiceClient.common_billing_account_path diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py index 373410e6e7..b46d9af934 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -197,22 +197,6 @@ def parse_index_endpoint_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} - @staticmethod - def index_endpoint_path(project: str, location: str, index_endpoint: str,) -> str: - """Return a fully-qualified index_endpoint string.""" - return "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( - project=project, location=location, index_endpoint=index_endpoint, - ) - - @staticmethod - def parse_index_endpoint_path(path: str) -> Dict[str, str]: - """Parse a index_endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/indexEndpoints/(?P.+?)$", - path, - ) - return m.groupdict() if m else {} - @staticmethod def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py index 346bd1bc1e..53b8c52df8 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py @@ -447,7 +447,7 @@ async def update_index( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateIndexRequest`): The request object. Request message for - [IndexService.UpdateModel][]. + ``IndexService.UpdateIndex``. index (:class:`google.cloud.aiplatform_v1beta1.types.Index`): Required. The Index which updates the resource on the server. diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py index b90771f405..d12b1fe06b 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -638,7 +638,7 @@ def update_index( Args: request (google.cloud.aiplatform_v1beta1.types.UpdateIndexRequest): The request object. Request message for - [IndexService.UpdateModel][]. + ``IndexService.UpdateIndex``. index (google.cloud.aiplatform_v1beta1.types.Index): Required. The Index which updates the resource on the server. diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 52061202bf..faa83aa192 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -109,6 +109,10 @@ class JobServiceAsyncClient: parse_model_deployment_monitoring_job_path = staticmethod( JobServiceClient.parse_model_deployment_monitoring_job_path ) + network_path = staticmethod(JobServiceClient.network_path) + parse_network_path = staticmethod(JobServiceClient.parse_network_path) + tensorboard_path = staticmethod(JobServiceClient.tensorboard_path) + parse_tensorboard_path = staticmethod(JobServiceClient.parse_tensorboard_path) trial_path = staticmethod(JobServiceClient.trial_path) parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) @@ -667,7 +671,7 @@ async def create_data_labeling_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.CreateDataLabelingJobRequest`): The request object. Request message for - [DataLabelingJobService.CreateDataLabelingJob][]. + ``JobService.CreateDataLabelingJob``. parent (:class:`str`): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` @@ -750,7 +754,7 @@ async def get_data_labeling_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.GetDataLabelingJobRequest`): The request object. Request message for - [DataLabelingJobService.GetDataLabelingJob][]. + ``JobService.GetDataLabelingJob``. name (:class:`str`): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` @@ -824,7 +828,7 @@ async def list_data_labeling_jobs( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsRequest`): The request object. Request message for - [DataLabelingJobService.ListDataLabelingJobs][]. + ``JobService.ListDataLabelingJobs``. parent (:class:`str`): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` @@ -1002,7 +1006,7 @@ async def cancel_data_labeling_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.CancelDataLabelingJobRequest`): The request object. Request message for - [DataLabelingJobService.CancelDataLabelingJob][]. + ``JobService.CancelDataLabelingJob``. name (:class:`str`): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` @@ -1914,7 +1918,7 @@ async def create_model_deployment_monitoring_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.CreateModelDeploymentMonitoringJobRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.CreateModelDeploymentMonitoringJob][]. + ``JobService.CreateModelDeploymentMonitoringJob``. parent (:class:`str`): Required. The parent of the ModelDeploymentMonitoringJob. Format: @@ -2002,7 +2006,7 @@ async def search_model_deployment_monitoring_stats_anomalies( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies``. model_deployment_monitoring_job (:class:`str`): Required. ModelDeploymentMonitoring Job resource name. Format: @@ -2028,7 +2032,7 @@ async def search_model_deployment_monitoring_stats_anomalies( Returns: google.cloud.aiplatform_v1beta1.services.job_service.pagers.SearchModelDeploymentMonitoringStatsAnomaliesAsyncPager: Response message for - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies``. Iterating over this object will yield results and resolve additional pages automatically. @@ -2103,7 +2107,7 @@ async def get_model_deployment_monitoring_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.GetModelDeploymentMonitoringJobRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.GetModelDeploymentMonitoringJob][]. + ``JobService.GetModelDeploymentMonitoringJob``. name (:class:`str`): Required. The resource name of the ModelDeploymentMonitoringJob. Format: @@ -2180,7 +2184,7 @@ async def list_model_deployment_monitoring_jobs( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + ``JobService.ListModelDeploymentMonitoringJobs``. parent (:class:`str`): Required. The parent of the ModelDeploymentMonitoringJob. Format: @@ -2199,7 +2203,7 @@ async def list_model_deployment_monitoring_jobs( Returns: google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListModelDeploymentMonitoringJobsAsyncPager: Response message for - [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + ``JobService.ListModelDeploymentMonitoringJobs``. Iterating over this object will yield results and resolve additional pages automatically. @@ -2264,7 +2268,7 @@ async def update_model_deployment_monitoring_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateModelDeploymentMonitoringJobRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + ``JobService.UpdateModelDeploymentMonitoringJob``. model_deployment_monitoring_job (:class:`google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob`): Required. The model monitoring configuration which replaces the @@ -2365,7 +2369,7 @@ async def delete_model_deployment_monitoring_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteModelDeploymentMonitoringJobRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.DeleteModelDeploymentMonitoringJob][]. + ``JobService.DeleteModelDeploymentMonitoringJob``. name (:class:`str`): Required. The resource name of the model monitoring job to delete. Format: @@ -2463,7 +2467,7 @@ async def pause_model_deployment_monitoring_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.PauseModelDeploymentMonitoringJobRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.PauseModelDeploymentMonitoringJob][]. + ``JobService.PauseModelDeploymentMonitoringJob``. name (:class:`str`): Required. The resource name of the ModelDeploymentMonitoringJob to pause. Format: @@ -2532,7 +2536,7 @@ async def resume_model_deployment_monitoring_job( Args: request (:class:`google.cloud.aiplatform_v1beta1.types.ResumeModelDeploymentMonitoringJobRequest`): The request object. Request message for - [ModelDeploymentMonitoringJobService.ResumeModelDeploymentMonitoringJob][]. + ``JobService.ResumeModelDeploymentMonitoringJob``. name (:class:`str`): Required. The resource name of the ModelDeploymentMonitoringJob to resume. Format: diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index 6f649532af..b0e1e65586 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -332,6 +332,37 @@ def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def network_path(project: str, network: str,) -> str: + """Return a fully-qualified network string.""" + return "projects/{project}/global/networks/{network}".format( + project=project, network=network, + ) + + @staticmethod + def parse_network_path(path: str) -> Dict[str, str]: + """Parse a network path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/global/networks/(?P.+?)$", path + ) + return m.groupdict() if m else {} + + @staticmethod + def tensorboard_path(project: str, location: str, tensorboard: str,) -> str: + """Return a fully-qualified tensorboard string.""" + return "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( + project=project, location=location, tensorboard=tensorboard, + ) + + @staticmethod + def parse_tensorboard_path(path: str) -> Dict[str, str]: + """Parse a tensorboard path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/tensorboards/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def trial_path(project: str, location: str, study: str, trial: str,) -> str: """Return a fully-qualified trial string.""" @@ -964,7 +995,7 @@ def create_data_labeling_job( Args: request (google.cloud.aiplatform_v1beta1.types.CreateDataLabelingJobRequest): The request object. Request message for - [DataLabelingJobService.CreateDataLabelingJob][]. + ``JobService.CreateDataLabelingJob``. parent (str): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` @@ -1048,7 +1079,7 @@ def get_data_labeling_job( Args: request (google.cloud.aiplatform_v1beta1.types.GetDataLabelingJobRequest): The request object. Request message for - [DataLabelingJobService.GetDataLabelingJob][]. + ``JobService.GetDataLabelingJob``. name (str): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` @@ -1123,7 +1154,7 @@ def list_data_labeling_jobs( Args: request (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsRequest): The request object. Request message for - [DataLabelingJobService.ListDataLabelingJobs][]. + ``JobService.ListDataLabelingJobs``. parent (str): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` @@ -1303,7 +1334,7 @@ def cancel_data_labeling_job( Args: request (google.cloud.aiplatform_v1beta1.types.CancelDataLabelingJobRequest): The request object. Request message for - [DataLabelingJobService.CancelDataLabelingJob][]. + ``JobService.CancelDataLabelingJob``. name (str): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` @@ -2244,7 +2275,7 @@ def create_model_deployment_monitoring_job( Args: request (google.cloud.aiplatform_v1beta1.types.CreateModelDeploymentMonitoringJobRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.CreateModelDeploymentMonitoringJob][]. + ``JobService.CreateModelDeploymentMonitoringJob``. parent (str): Required. The parent of the ModelDeploymentMonitoringJob. Format: @@ -2339,7 +2370,7 @@ def search_model_deployment_monitoring_stats_anomalies( Args: request (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies``. model_deployment_monitoring_job (str): Required. ModelDeploymentMonitoring Job resource name. Format: @@ -2365,7 +2396,7 @@ def search_model_deployment_monitoring_stats_anomalies( Returns: google.cloud.aiplatform_v1beta1.services.job_service.pagers.SearchModelDeploymentMonitoringStatsAnomaliesPager: Response message for - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies``. Iterating over this object will yield results and resolve additional pages automatically. @@ -2447,7 +2478,7 @@ def get_model_deployment_monitoring_job( Args: request (google.cloud.aiplatform_v1beta1.types.GetModelDeploymentMonitoringJobRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.GetModelDeploymentMonitoringJob][]. + ``JobService.GetModelDeploymentMonitoringJob``. name (str): Required. The resource name of the ModelDeploymentMonitoringJob. Format: @@ -2527,7 +2558,7 @@ def list_model_deployment_monitoring_jobs( Args: request (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + ``JobService.ListModelDeploymentMonitoringJobs``. parent (str): Required. The parent of the ModelDeploymentMonitoringJob. Format: @@ -2546,7 +2577,7 @@ def list_model_deployment_monitoring_jobs( Returns: google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListModelDeploymentMonitoringJobsPager: Response message for - [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + ``JobService.ListModelDeploymentMonitoringJobs``. Iterating over this object will yield results and resolve additional pages automatically. @@ -2616,7 +2647,7 @@ def update_model_deployment_monitoring_job( Args: request (google.cloud.aiplatform_v1beta1.types.UpdateModelDeploymentMonitoringJobRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + ``JobService.UpdateModelDeploymentMonitoringJob``. model_deployment_monitoring_job (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob): Required. The model monitoring configuration which replaces the @@ -2724,7 +2755,7 @@ def delete_model_deployment_monitoring_job( Args: request (google.cloud.aiplatform_v1beta1.types.DeleteModelDeploymentMonitoringJobRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.DeleteModelDeploymentMonitoringJob][]. + ``JobService.DeleteModelDeploymentMonitoringJob``. name (str): Required. The resource name of the model monitoring job to delete. Format: @@ -2827,7 +2858,7 @@ def pause_model_deployment_monitoring_job( Args: request (google.cloud.aiplatform_v1beta1.types.PauseModelDeploymentMonitoringJobRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.PauseModelDeploymentMonitoringJob][]. + ``JobService.PauseModelDeploymentMonitoringJob``. name (str): Required. The resource name of the ModelDeploymentMonitoringJob to pause. Format: @@ -2901,7 +2932,7 @@ def resume_model_deployment_monitoring_job( Args: request (google.cloud.aiplatform_v1beta1.types.ResumeModelDeploymentMonitoringJobRequest): The request object. Request message for - [ModelDeploymentMonitoringJobService.ResumeModelDeploymentMonitoringJob][]. + ``JobService.ResumeModelDeploymentMonitoringJob``. name (str): Required. The resource name of the ModelDeploymentMonitoringJob to resume. Format: diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index d4324e3089..501f21183f 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -180,32 +180,32 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 19ba32207c..a8e4ad20c8 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -34,6 +34,8 @@ from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_job +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -58,10 +60,24 @@ class PipelineServiceAsyncClient: DEFAULT_ENDPOINT = PipelineServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = PipelineServiceClient.DEFAULT_MTLS_ENDPOINT + artifact_path = staticmethod(PipelineServiceClient.artifact_path) + parse_artifact_path = staticmethod(PipelineServiceClient.parse_artifact_path) + context_path = staticmethod(PipelineServiceClient.context_path) + parse_context_path = staticmethod(PipelineServiceClient.parse_context_path) + custom_job_path = staticmethod(PipelineServiceClient.custom_job_path) + parse_custom_job_path = staticmethod(PipelineServiceClient.parse_custom_job_path) endpoint_path = staticmethod(PipelineServiceClient.endpoint_path) parse_endpoint_path = staticmethod(PipelineServiceClient.parse_endpoint_path) + execution_path = staticmethod(PipelineServiceClient.execution_path) + parse_execution_path = staticmethod(PipelineServiceClient.parse_execution_path) model_path = staticmethod(PipelineServiceClient.model_path) parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) + network_path = staticmethod(PipelineServiceClient.network_path) + parse_network_path = staticmethod(PipelineServiceClient.parse_network_path) + pipeline_job_path = staticmethod(PipelineServiceClient.pipeline_job_path) + parse_pipeline_job_path = staticmethod( + PipelineServiceClient.parse_pipeline_job_path + ) training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) parse_training_pipeline_path = staticmethod( PipelineServiceClient.parse_training_pipeline_path @@ -613,6 +629,432 @@ async def cancel_training_pipeline( request, retry=retry, timeout=timeout, metadata=metadata, ) + async def create_pipeline_job( + self, + request: pipeline_service.CreatePipelineJobRequest = None, + *, + parent: str = None, + pipeline_job: gca_pipeline_job.PipelineJob = None, + pipeline_job_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_pipeline_job.PipelineJob: + r"""Creates a PipelineJob. A PipelineJob will run + immediately when created. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreatePipelineJobRequest`): + The request object. Request message for + ``PipelineService.CreatePipelineJob``. + parent (:class:`str`): + Required. The resource name of the Location to create + the PipelineJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + pipeline_job (:class:`google.cloud.aiplatform_v1beta1.types.PipelineJob`): + Required. The PipelineJob to create. + This corresponds to the ``pipeline_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + pipeline_job_id (:class:`str`): + The ID to use for the PipelineJob, which will become the + final component of the PipelineJob name. If not + provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid + characters are /[a-z][0-9]-/. + + This corresponds to the ``pipeline_job_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.PipelineJob: + An instance of a machine learning + PipelineJob. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.CreatePipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if pipeline_job is not None: + request.pipeline_job = pipeline_job + if pipeline_job_id is not None: + request.pipeline_job_id = pipeline_job_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_pipeline_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_pipeline_job( + self, + request: pipeline_service.GetPipelineJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pipeline_job.PipelineJob: + r"""Gets a PipelineJob. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetPipelineJobRequest`): + The request object. Request message for + ``PipelineService.GetPipelineJob``. + name (:class:`str`): + Required. The name of the PipelineJob resource. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.PipelineJob: + An instance of a machine learning + PipelineJob. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.GetPipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_pipeline_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_pipeline_jobs( + self, + request: pipeline_service.ListPipelineJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListPipelineJobsAsyncPager: + r"""Lists PipelineJobs in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListPipelineJobsRequest`): + The request object. Request message for + ``PipelineService.ListPipelineJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + PipelineJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.pipeline_service.pagers.ListPipelineJobsAsyncPager: + Response message for + ``PipelineService.ListPipelineJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.ListPipelineJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_pipeline_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListPipelineJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_pipeline_job( + self, + request: pipeline_service.DeletePipelineJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a PipelineJob. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeletePipelineJobRequest`): + The request object. Request message for + ``PipelineService.DeletePipelineJob``. + name (:class:`str`): + Required. The name of the PipelineJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.DeletePipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_pipeline_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_pipeline_job( + self, + request: pipeline_service.CancelPipelineJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a PipelineJob. Starts asynchronous cancellation on the + PipelineJob. The server makes a best effort to cancel the + pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetPipelineJob`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the PipelineJob is not deleted; instead + it becomes a pipeline with a + ``PipelineJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``PipelineJob.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CancelPipelineJobRequest`): + The request object. Request message for + ``PipelineService.CancelPipelineJob``. + name (:class:`str`): + Required. The name of the PipelineJob to cancel. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.CancelPipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_pipeline_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 9f61aff314..af3da01d41 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -38,6 +38,8 @@ from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_job +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -169,6 +171,64 @@ def transport(self) -> PipelineServiceTransport: """ return self._transport + @staticmethod + def artifact_path( + project: str, location: str, metadata_store: str, artifact: str, + ) -> str: + """Return a fully-qualified artifact string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) + + @staticmethod + def parse_artifact_path(path: str) -> Dict[str, str]: + """Parse a artifact path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/artifacts/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def context_path( + project: str, location: str, metadata_store: str, context: str, + ) -> str: + """Return a fully-qualified context string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + + @staticmethod + def parse_context_path(path: str) -> Dict[str, str]: + """Parse a context path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def custom_job_path(project: str, location: str, custom_job: str,) -> str: + """Return a fully-qualified custom_job string.""" + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + + @staticmethod + def parse_custom_job_path(path: str) -> Dict[str, str]: + """Parse a custom_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" @@ -185,6 +245,27 @@ def parse_endpoint_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def execution_path( + project: str, location: str, metadata_store: str, execution: str, + ) -> str: + """Return a fully-qualified execution string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) + + @staticmethod + def parse_execution_path(path: str) -> Dict[str, str]: + """Parse a execution path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/executions/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" @@ -201,6 +282,37 @@ def parse_model_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def network_path(project: str, network: str,) -> str: + """Return a fully-qualified network string.""" + return "projects/{project}/global/networks/{network}".format( + project=project, network=network, + ) + + @staticmethod + def parse_network_path(path: str) -> Dict[str, str]: + """Parse a network path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/global/networks/(?P.+?)$", path + ) + return m.groupdict() if m else {} + + @staticmethod + def pipeline_job_path(project: str, location: str, pipeline_job: str,) -> str: + """Return a fully-qualified pipeline_job string.""" + return "projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}".format( + project=project, location=location, pipeline_job=pipeline_job, + ) + + @staticmethod + def parse_pipeline_job_path(path: str) -> Dict[str, str]: + """Parse a pipeline_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/pipelineJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def training_pipeline_path( project: str, location: str, training_pipeline: str, @@ -820,6 +932,437 @@ def cancel_training_pipeline( request, retry=retry, timeout=timeout, metadata=metadata, ) + def create_pipeline_job( + self, + request: pipeline_service.CreatePipelineJobRequest = None, + *, + parent: str = None, + pipeline_job: gca_pipeline_job.PipelineJob = None, + pipeline_job_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_pipeline_job.PipelineJob: + r"""Creates a PipelineJob. A PipelineJob will run + immediately when created. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreatePipelineJobRequest): + The request object. Request message for + ``PipelineService.CreatePipelineJob``. + parent (str): + Required. The resource name of the Location to create + the PipelineJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + pipeline_job (google.cloud.aiplatform_v1beta1.types.PipelineJob): + Required. The PipelineJob to create. + This corresponds to the ``pipeline_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + pipeline_job_id (str): + The ID to use for the PipelineJob, which will become the + final component of the PipelineJob name. If not + provided, an ID will be automatically generated. + + This value should be less than 128 characters, and valid + characters are /[a-z][0-9]-/. + + This corresponds to the ``pipeline_job_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.PipelineJob: + An instance of a machine learning + PipelineJob. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CreatePipelineJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CreatePipelineJobRequest): + request = pipeline_service.CreatePipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if pipeline_job is not None: + request.pipeline_job = pipeline_job + if pipeline_job_id is not None: + request.pipeline_job_id = pipeline_job_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_pipeline_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_pipeline_job( + self, + request: pipeline_service.GetPipelineJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pipeline_job.PipelineJob: + r"""Gets a PipelineJob. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetPipelineJobRequest): + The request object. Request message for + ``PipelineService.GetPipelineJob``. + name (str): + Required. The name of the PipelineJob resource. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.PipelineJob: + An instance of a machine learning + PipelineJob. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.GetPipelineJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.GetPipelineJobRequest): + request = pipeline_service.GetPipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_pipeline_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_pipeline_jobs( + self, + request: pipeline_service.ListPipelineJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListPipelineJobsPager: + r"""Lists PipelineJobs in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsRequest): + The request object. Request message for + ``PipelineService.ListPipelineJobs``. + parent (str): + Required. The resource name of the Location to list the + PipelineJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.pipeline_service.pagers.ListPipelineJobsPager: + Response message for + ``PipelineService.ListPipelineJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.ListPipelineJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.ListPipelineJobsRequest): + request = pipeline_service.ListPipelineJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_pipeline_jobs] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPipelineJobsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_pipeline_job( + self, + request: pipeline_service.DeletePipelineJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a PipelineJob. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeletePipelineJobRequest): + The request object. Request message for + ``PipelineService.DeletePipelineJob``. + name (str): + Required. The name of the PipelineJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.DeletePipelineJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.DeletePipelineJobRequest): + request = pipeline_service.DeletePipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_pipeline_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def cancel_pipeline_job( + self, + request: pipeline_service.CancelPipelineJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a PipelineJob. Starts asynchronous cancellation on the + PipelineJob. The server makes a best effort to cancel the + pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetPipelineJob`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the PipelineJob is not deleted; instead + it becomes a pipeline with a + ``PipelineJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``PipelineJob.state`` + is set to ``CANCELLED``. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CancelPipelineJobRequest): + The request object. Request message for + ``PipelineService.CancelPipelineJob``. + name (str): + Required. The name of the PipelineJob to cancel. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CancelPipelineJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CancelPipelineJobRequest): + request = pipeline_service.CancelPipelineJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.cancel_pipeline_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index db2b4dd3a1..0a4aa3bbc5 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -26,6 +26,7 @@ Optional, ) +from google.cloud.aiplatform_v1beta1.types import pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -160,3 +161,131 @@ async def async_generator(): def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPipelineJobsPager: + """A pager for iterating through ``list_pipeline_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``pipeline_jobs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListPipelineJobs`` requests and continue to iterate + through the ``pipeline_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., pipeline_service.ListPipelineJobsResponse], + request: pipeline_service.ListPipelineJobsRequest, + response: pipeline_service.ListPipelineJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = pipeline_service.ListPipelineJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[pipeline_service.ListPipelineJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[pipeline_job.PipelineJob]: + for page in self.pages: + yield from page.pipeline_jobs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPipelineJobsAsyncPager: + """A pager for iterating through ``list_pipeline_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``pipeline_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListPipelineJobs`` requests and continue to iterate + through the ``pipeline_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[pipeline_service.ListPipelineJobsResponse]], + request: pipeline_service.ListPipelineJobsRequest, + response: pipeline_service.ListPipelineJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = pipeline_service.ListPipelineJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[pipeline_service.ListPipelineJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[pipeline_job.PipelineJob]: + async def async_generator(): + async for page in self.pages: + for response in page.pipeline_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 886219917f..70ad468804 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -26,6 +26,8 @@ from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore +from google.cloud.aiplatform_v1beta1.types import pipeline_job +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline from google.cloud.aiplatform_v1beta1.types import ( @@ -138,6 +140,21 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), + self.create_pipeline_job: gapic_v1.method.wrap_method( + self.create_pipeline_job, default_timeout=None, client_info=client_info, + ), + self.get_pipeline_job: gapic_v1.method.wrap_method( + self.get_pipeline_job, default_timeout=None, client_info=client_info, + ), + self.list_pipeline_jobs: gapic_v1.method.wrap_method( + self.list_pipeline_jobs, default_timeout=None, client_info=client_info, + ), + self.delete_pipeline_job: gapic_v1.method.wrap_method( + self.delete_pipeline_job, default_timeout=None, client_info=client_info, + ), + self.cancel_pipeline_job: gapic_v1.method.wrap_method( + self.cancel_pipeline_job, default_timeout=None, client_info=client_info, + ), } @property @@ -199,5 +216,57 @@ def cancel_training_pipeline( ]: raise NotImplementedError() + @property + def create_pipeline_job( + self, + ) -> typing.Callable[ + [pipeline_service.CreatePipelineJobRequest], + typing.Union[ + gca_pipeline_job.PipelineJob, typing.Awaitable[gca_pipeline_job.PipelineJob] + ], + ]: + raise NotImplementedError() + + @property + def get_pipeline_job( + self, + ) -> typing.Callable[ + [pipeline_service.GetPipelineJobRequest], + typing.Union[ + pipeline_job.PipelineJob, typing.Awaitable[pipeline_job.PipelineJob] + ], + ]: + raise NotImplementedError() + + @property + def list_pipeline_jobs( + self, + ) -> typing.Callable[ + [pipeline_service.ListPipelineJobsRequest], + typing.Union[ + pipeline_service.ListPipelineJobsResponse, + typing.Awaitable[pipeline_service.ListPipelineJobsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_pipeline_job( + self, + ) -> typing.Callable[ + [pipeline_service.DeletePipelineJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_pipeline_job( + self, + ) -> typing.Callable[ + [pipeline_service.CancelPipelineJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + __all__ = ("PipelineServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 8004a9a0a7..dfe8ad5e89 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -27,6 +27,8 @@ import grpc # type: ignore +from google.cloud.aiplatform_v1beta1.types import pipeline_job +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline from google.cloud.aiplatform_v1beta1.types import ( @@ -397,5 +399,153 @@ def cancel_training_pipeline( ) return self._stubs["cancel_training_pipeline"] + @property + def create_pipeline_job( + self, + ) -> Callable[ + [pipeline_service.CreatePipelineJobRequest], gca_pipeline_job.PipelineJob + ]: + r"""Return a callable for the create pipeline job method over gRPC. + + Creates a PipelineJob. A PipelineJob will run + immediately when created. + + Returns: + Callable[[~.CreatePipelineJobRequest], + ~.PipelineJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_pipeline_job" not in self._stubs: + self._stubs["create_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreatePipelineJob", + request_serializer=pipeline_service.CreatePipelineJobRequest.serialize, + response_deserializer=gca_pipeline_job.PipelineJob.deserialize, + ) + return self._stubs["create_pipeline_job"] + + @property + def get_pipeline_job( + self, + ) -> Callable[[pipeline_service.GetPipelineJobRequest], pipeline_job.PipelineJob]: + r"""Return a callable for the get pipeline job method over gRPC. + + Gets a PipelineJob. + + Returns: + Callable[[~.GetPipelineJobRequest], + ~.PipelineJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_pipeline_job" not in self._stubs: + self._stubs["get_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetPipelineJob", + request_serializer=pipeline_service.GetPipelineJobRequest.serialize, + response_deserializer=pipeline_job.PipelineJob.deserialize, + ) + return self._stubs["get_pipeline_job"] + + @property + def list_pipeline_jobs( + self, + ) -> Callable[ + [pipeline_service.ListPipelineJobsRequest], + pipeline_service.ListPipelineJobsResponse, + ]: + r"""Return a callable for the list pipeline jobs method over gRPC. + + Lists PipelineJobs in a Location. + + Returns: + Callable[[~.ListPipelineJobsRequest], + ~.ListPipelineJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_pipeline_jobs" not in self._stubs: + self._stubs["list_pipeline_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListPipelineJobs", + request_serializer=pipeline_service.ListPipelineJobsRequest.serialize, + response_deserializer=pipeline_service.ListPipelineJobsResponse.deserialize, + ) + return self._stubs["list_pipeline_jobs"] + + @property + def delete_pipeline_job( + self, + ) -> Callable[[pipeline_service.DeletePipelineJobRequest], operations.Operation]: + r"""Return a callable for the delete pipeline job method over gRPC. + + Deletes a PipelineJob. + + Returns: + Callable[[~.DeletePipelineJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_pipeline_job" not in self._stubs: + self._stubs["delete_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeletePipelineJob", + request_serializer=pipeline_service.DeletePipelineJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_pipeline_job"] + + @property + def cancel_pipeline_job( + self, + ) -> Callable[[pipeline_service.CancelPipelineJobRequest], empty.Empty]: + r"""Return a callable for the cancel pipeline job method over gRPC. + + Cancels a PipelineJob. Starts asynchronous cancellation on the + PipelineJob. The server makes a best effort to cancel the + pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetPipelineJob`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the PipelineJob is not deleted; instead + it becomes a pipeline with a + ``PipelineJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``PipelineJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelPipelineJobRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_pipeline_job" not in self._stubs: + self._stubs["cancel_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelPipelineJob", + request_serializer=pipeline_service.CancelPipelineJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_pipeline_job"] + __all__ = ("PipelineServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index a268ec1cd2..f813118e1e 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -28,6 +28,8 @@ import grpc # type: ignore from grpc.experimental import aio # type: ignore +from google.cloud.aiplatform_v1beta1.types import pipeline_job +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline from google.cloud.aiplatform_v1beta1.types import ( @@ -406,5 +408,158 @@ def cancel_training_pipeline( ) return self._stubs["cancel_training_pipeline"] + @property + def create_pipeline_job( + self, + ) -> Callable[ + [pipeline_service.CreatePipelineJobRequest], + Awaitable[gca_pipeline_job.PipelineJob], + ]: + r"""Return a callable for the create pipeline job method over gRPC. + + Creates a PipelineJob. A PipelineJob will run + immediately when created. + + Returns: + Callable[[~.CreatePipelineJobRequest], + Awaitable[~.PipelineJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_pipeline_job" not in self._stubs: + self._stubs["create_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreatePipelineJob", + request_serializer=pipeline_service.CreatePipelineJobRequest.serialize, + response_deserializer=gca_pipeline_job.PipelineJob.deserialize, + ) + return self._stubs["create_pipeline_job"] + + @property + def get_pipeline_job( + self, + ) -> Callable[ + [pipeline_service.GetPipelineJobRequest], Awaitable[pipeline_job.PipelineJob] + ]: + r"""Return a callable for the get pipeline job method over gRPC. + + Gets a PipelineJob. + + Returns: + Callable[[~.GetPipelineJobRequest], + Awaitable[~.PipelineJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_pipeline_job" not in self._stubs: + self._stubs["get_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetPipelineJob", + request_serializer=pipeline_service.GetPipelineJobRequest.serialize, + response_deserializer=pipeline_job.PipelineJob.deserialize, + ) + return self._stubs["get_pipeline_job"] + + @property + def list_pipeline_jobs( + self, + ) -> Callable[ + [pipeline_service.ListPipelineJobsRequest], + Awaitable[pipeline_service.ListPipelineJobsResponse], + ]: + r"""Return a callable for the list pipeline jobs method over gRPC. + + Lists PipelineJobs in a Location. + + Returns: + Callable[[~.ListPipelineJobsRequest], + Awaitable[~.ListPipelineJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_pipeline_jobs" not in self._stubs: + self._stubs["list_pipeline_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListPipelineJobs", + request_serializer=pipeline_service.ListPipelineJobsRequest.serialize, + response_deserializer=pipeline_service.ListPipelineJobsResponse.deserialize, + ) + return self._stubs["list_pipeline_jobs"] + + @property + def delete_pipeline_job( + self, + ) -> Callable[ + [pipeline_service.DeletePipelineJobRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete pipeline job method over gRPC. + + Deletes a PipelineJob. + + Returns: + Callable[[~.DeletePipelineJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_pipeline_job" not in self._stubs: + self._stubs["delete_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeletePipelineJob", + request_serializer=pipeline_service.DeletePipelineJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_pipeline_job"] + + @property + def cancel_pipeline_job( + self, + ) -> Callable[[pipeline_service.CancelPipelineJobRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the cancel pipeline job method over gRPC. + + Cancels a PipelineJob. Starts asynchronous cancellation on the + PipelineJob. The server makes a best effort to cancel the + pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetPipelineJob`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the PipelineJob is not deleted; instead + it becomes a pipeline with a + ``PipelineJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``PipelineJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelPipelineJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_pipeline_job" not in self._stubs: + self._stubs["cancel_pipeline_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelPipelineJob", + request_serializer=pipeline_service.CancelPipelineJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_pipeline_job"] + __all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/__init__.py new file mode 100644 index 0000000000..70277571f7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import TensorboardServiceClient +from .async_client import TensorboardServiceAsyncClient + +__all__ = ( + "TensorboardServiceClient", + "TensorboardServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py new file mode 100644 index 0000000000..144488ff29 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py @@ -0,0 +1,2346 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, AsyncIterable, Awaitable, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import pagers +from google.cloud.aiplatform_v1beta1.types import encryption_spec +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard as gca_tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_data +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import TensorboardServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import TensorboardServiceGrpcAsyncIOTransport +from .client import TensorboardServiceClient + + +class TensorboardServiceAsyncClient: + """TensorboardService""" + + _client: TensorboardServiceClient + + DEFAULT_ENDPOINT = TensorboardServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = TensorboardServiceClient.DEFAULT_MTLS_ENDPOINT + + tensorboard_path = staticmethod(TensorboardServiceClient.tensorboard_path) + parse_tensorboard_path = staticmethod( + TensorboardServiceClient.parse_tensorboard_path + ) + tensorboard_experiment_path = staticmethod( + TensorboardServiceClient.tensorboard_experiment_path + ) + parse_tensorboard_experiment_path = staticmethod( + TensorboardServiceClient.parse_tensorboard_experiment_path + ) + tensorboard_run_path = staticmethod(TensorboardServiceClient.tensorboard_run_path) + parse_tensorboard_run_path = staticmethod( + TensorboardServiceClient.parse_tensorboard_run_path + ) + tensorboard_time_series_path = staticmethod( + TensorboardServiceClient.tensorboard_time_series_path + ) + parse_tensorboard_time_series_path = staticmethod( + TensorboardServiceClient.parse_tensorboard_time_series_path + ) + + common_billing_account_path = staticmethod( + TensorboardServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + TensorboardServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(TensorboardServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + TensorboardServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + TensorboardServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + TensorboardServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(TensorboardServiceClient.common_project_path) + parse_common_project_path = staticmethod( + TensorboardServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(TensorboardServiceClient.common_location_path) + parse_common_location_path = staticmethod( + TensorboardServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TensorboardServiceAsyncClient: The constructed client. + """ + return TensorboardServiceClient.from_service_account_info.__func__(TensorboardServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TensorboardServiceAsyncClient: The constructed client. + """ + return TensorboardServiceClient.from_service_account_file.__func__(TensorboardServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TensorboardServiceTransport: + """Return the transport used by the client instance. + + Returns: + TensorboardServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(TensorboardServiceClient).get_transport_class, + type(TensorboardServiceClient), + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, TensorboardServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the tensorboard service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.TensorboardServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = TensorboardServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_tensorboard( + self, + request: tensorboard_service.CreateTensorboardRequest = None, + *, + parent: str = None, + tensorboard: gca_tensorboard.Tensorboard = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a Tensorboard. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateTensorboardRequest`): + The request object. Request message for + ``TensorboardService.CreateTensorboard``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Tensorboard in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard (:class:`google.cloud.aiplatform_v1beta1.types.Tensorboard`): + Required. The Tensorboard to create. + This corresponds to the ``tensorboard`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Tensorboard` Tensorboard is a physical database that stores users’ training metrics. + A default Tensorboard is provided in each region of a + GCP project. If needed users can also create extra + Tensorboards in their projects. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, tensorboard]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.CreateTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard is not None: + request.tensorboard = tensorboard + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_tensorboard, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_tensorboard.Tensorboard, + metadata_type=tensorboard_service.CreateTensorboardOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_tensorboard( + self, + request: tensorboard_service.GetTensorboardRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard.Tensorboard: + r"""Gets a Tensorboard. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetTensorboardRequest`): + The request object. Request message for + ``TensorboardService.GetTensorboard``. + name (:class:`str`): + Required. The name of the Tensorboard resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Tensorboard: + Tensorboard is a physical database + that stores users’ training metrics. A + default Tensorboard is provided in each + region of a GCP project. If needed users + can also create extra Tensorboards in + their projects. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.GetTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_tensorboard, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def update_tensorboard( + self, + request: tensorboard_service.UpdateTensorboardRequest = None, + *, + tensorboard: gca_tensorboard.Tensorboard = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a Tensorboard. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateTensorboardRequest`): + The request object. Request message for + ``TensorboardService.UpdateTensorboard``. + tensorboard (:class:`google.cloud.aiplatform_v1beta1.types.Tensorboard`): + Required. The Tensorboard's ``name`` field is used to + identify the Tensorboard to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``tensorboard`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. Field mask is used to specify the fields to be + overwritten in the Tensorboard resource by the update. + The fields specified in the update_mask are relative to + the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then all fields will be overwritten if + new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Tensorboard` Tensorboard is a physical database that stores users’ training metrics. + A default Tensorboard is provided in each region of a + GCP project. If needed users can also create extra + Tensorboards in their projects. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.UpdateTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard is not None: + request.tensorboard = tensorboard + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_tensorboard, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard.name", request.tensorboard.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_tensorboard.Tensorboard, + metadata_type=tensorboard_service.UpdateTensorboardOperationMetadata, + ) + + # Done; return the response. + return response + + async def list_tensorboards( + self, + request: tensorboard_service.ListTensorboardsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardsAsyncPager: + r"""Lists Tensorboards in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardsRequest`): + The request object. Request message for + ``TensorboardService.ListTensorboards``. + parent (:class:`str`): + Required. The resource name of the + Location to list Tensorboards. Format: + 'projects/{project}/locations/{location}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardsAsyncPager: + Response message for + ``TensorboardService.ListTensorboards``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ListTensorboardsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_tensorboards, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTensorboardsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_tensorboard( + self, + request: tensorboard_service.DeleteTensorboardRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Tensorboard. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteTensorboardRequest`): + The request object. Request message for + ``TensorboardService.DeleteTensorboard``. + name (:class:`str`): + Required. The name of the Tensorboard to be deleted. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.DeleteTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_tensorboard, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def create_tensorboard_experiment( + self, + request: tensorboard_service.CreateTensorboardExperimentRequest = None, + *, + parent: str = None, + tensorboard_experiment: gca_tensorboard_experiment.TensorboardExperiment = None, + tensorboard_experiment_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_experiment.TensorboardExperiment: + r"""Creates a TensorboardExperiment. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateTensorboardExperimentRequest`): + The request object. Request message for + ``TensorboardService.CreateTensorboardExperiment``. + parent (:class:`str`): + Required. The resource name of the Tensorboard to create + the TensorboardExperiment in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_experiment (:class:`google.cloud.aiplatform_v1beta1.types.TensorboardExperiment`): + The TensorboardExperiment to create. + This corresponds to the ``tensorboard_experiment`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_experiment_id (:class:`str`): + Required. The ID to use for the Tensorboard experiment, + which will become the final component of the Tensorboard + experiment's resource name. + + This value should be 1-128 characters, and valid + characters are /[a-z][0-9]-/. + + This corresponds to the ``tensorboard_experiment_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardExperiment: + A TensorboardExperiment is a group of + TensorboardRuns, that are typically the + results of a training job run, in a + Tensorboard. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any( + [parent, tensorboard_experiment, tensorboard_experiment_id] + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.CreateTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard_experiment is not None: + request.tensorboard_experiment = tensorboard_experiment + if tensorboard_experiment_id is not None: + request.tensorboard_experiment_id = tensorboard_experiment_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_tensorboard_experiment, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_tensorboard_experiment( + self, + request: tensorboard_service.GetTensorboardExperimentRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_experiment.TensorboardExperiment: + r"""Gets a TensorboardExperiment. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetTensorboardExperimentRequest`): + The request object. Request message for + ``TensorboardService.GetTensorboardExperiment``. + name (:class:`str`): + Required. The name of the TensorboardExperiment + resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardExperiment: + A TensorboardExperiment is a group of + TensorboardRuns, that are typically the + results of a training job run, in a + Tensorboard. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.GetTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_tensorboard_experiment, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def update_tensorboard_experiment( + self, + request: tensorboard_service.UpdateTensorboardExperimentRequest = None, + *, + tensorboard_experiment: gca_tensorboard_experiment.TensorboardExperiment = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_experiment.TensorboardExperiment: + r"""Updates a TensorboardExperiment. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateTensorboardExperimentRequest`): + The request object. Request message for + ``TensorboardService.UpdateTensorboardExperiment``. + tensorboard_experiment (:class:`google.cloud.aiplatform_v1beta1.types.TensorboardExperiment`): + Required. The TensorboardExperiment's ``name`` field is + used to identify the TensorboardExperiment to be + updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``tensorboard_experiment`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardExperiment resource by the + update. The fields specified in the update_mask are + relative to the resource, not the full request. A field + will be overwritten if it is in the mask. If the user + does not provide a mask then all fields will be + overwritten if new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardExperiment: + A TensorboardExperiment is a group of + TensorboardRuns, that are typically the + results of a training job run, in a + Tensorboard. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_experiment, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.UpdateTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_experiment is not None: + request.tensorboard_experiment = tensorboard_experiment + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_tensorboard_experiment, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_experiment.name", request.tensorboard_experiment.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_tensorboard_experiments( + self, + request: tensorboard_service.ListTensorboardExperimentsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardExperimentsAsyncPager: + r"""Lists TensorboardExperiments in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsRequest`): + The request object. Request message for + ``TensorboardService.ListTensorboardExperiments``. + parent (:class:`str`): + Required. The resource name of the + Tensorboard to list + TensorboardExperiments. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardExperimentsAsyncPager: + Response message for + ``TensorboardService.ListTensorboardExperiments``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ListTensorboardExperimentsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_tensorboard_experiments, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTensorboardExperimentsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_tensorboard_experiment( + self, + request: tensorboard_service.DeleteTensorboardExperimentRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a TensorboardExperiment. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteTensorboardExperimentRequest`): + The request object. Request message for + ``TensorboardService.DeleteTensorboardExperiment``. + name (:class:`str`): + Required. The name of the TensorboardExperiment to be + deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.DeleteTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_tensorboard_experiment, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def create_tensorboard_run( + self, + request: tensorboard_service.CreateTensorboardRunRequest = None, + *, + parent: str = None, + tensorboard_run: gca_tensorboard_run.TensorboardRun = None, + tensorboard_run_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_run.TensorboardRun: + r"""Creates a TensorboardRun. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateTensorboardRunRequest`): + The request object. Request message for + ``TensorboardService.CreateTensorboardRun``. + parent (:class:`str`): + Required. The resource name of the Tensorboard to create + the TensorboardRun in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_run (:class:`google.cloud.aiplatform_v1beta1.types.TensorboardRun`): + Required. The TensorboardRun to + create. + + This corresponds to the ``tensorboard_run`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_run_id (:class:`str`): + Required. The ID to use for the Tensorboard run, which + will become the final component of the Tensorboard run's + resource name. + + This value should be 1-128 characters, and valid + characters are /[a-z][0-9]-/. + + This corresponds to the ``tensorboard_run_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardRun: + TensorboardRun maps to a specific + execution of a training job with a given + set of hyperparameter values, model + definition, dataset, etc + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.CreateTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard_run is not None: + request.tensorboard_run = tensorboard_run + if tensorboard_run_id is not None: + request.tensorboard_run_id = tensorboard_run_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_tensorboard_run, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_tensorboard_run( + self, + request: tensorboard_service.GetTensorboardRunRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_run.TensorboardRun: + r"""Gets a TensorboardRun. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetTensorboardRunRequest`): + The request object. Request message for + ``TensorboardService.GetTensorboardRun``. + name (:class:`str`): + Required. The name of the TensorboardRun resource. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardRun: + TensorboardRun maps to a specific + execution of a training job with a given + set of hyperparameter values, model + definition, dataset, etc + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.GetTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_tensorboard_run, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def update_tensorboard_run( + self, + request: tensorboard_service.UpdateTensorboardRunRequest = None, + *, + tensorboard_run: gca_tensorboard_run.TensorboardRun = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_run.TensorboardRun: + r"""Updates a TensorboardRun. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateTensorboardRunRequest`): + The request object. Request message for + ``TensorboardService.UpdateTensorboardRun``. + tensorboard_run (:class:`google.cloud.aiplatform_v1beta1.types.TensorboardRun`): + Required. The TensorboardRun's ``name`` field is used to + identify the TensorboardRun to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``tensorboard_run`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardRun resource by the + update. The fields specified in the update_mask are + relative to the resource, not the full request. A field + will be overwritten if it is in the mask. If the user + does not provide a mask then all fields will be + overwritten if new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardRun: + TensorboardRun maps to a specific + execution of a training job with a given + set of hyperparameter values, model + definition, dataset, etc + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_run, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.UpdateTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_run is not None: + request.tensorboard_run = tensorboard_run + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_tensorboard_run, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_run.name", request.tensorboard_run.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_tensorboard_runs( + self, + request: tensorboard_service.ListTensorboardRunsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardRunsAsyncPager: + r"""Lists TensorboardRuns in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsRequest`): + The request object. Request message for + ``TensorboardService.ListTensorboardRuns``. + parent (:class:`str`): + Required. The resource name of the + Tensorboard to list TensorboardRuns. + Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardRunsAsyncPager: + Response message for + ``TensorboardService.ListTensorboardRuns``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ListTensorboardRunsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_tensorboard_runs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTensorboardRunsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_tensorboard_run( + self, + request: tensorboard_service.DeleteTensorboardRunRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a TensorboardRun. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteTensorboardRunRequest`): + The request object. Request message for + ``TensorboardService.DeleteTensorboardRun``. + name (:class:`str`): + Required. The name of the TensorboardRun to be deleted. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.DeleteTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_tensorboard_run, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def create_tensorboard_time_series( + self, + request: tensorboard_service.CreateTensorboardTimeSeriesRequest = None, + *, + parent: str = None, + tensorboard_time_series: gca_tensorboard_time_series.TensorboardTimeSeries = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_time_series.TensorboardTimeSeries: + r"""Creates a TensorboardTimeSeries. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateTensorboardTimeSeriesRequest`): + The request object. Request message for + ``TensorboardService.CreateTensorboardTimeSeries``. + parent (:class:`str`): + Required. The resource name of the TensorboardRun to + create the TensorboardTimeSeries in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_time_series (:class:`google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries`): + Required. The TensorboardTimeSeries + to create. + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries: + TensorboardTimeSeries maps to times + series produced in training runs + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, tensorboard_time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.CreateTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_tensorboard_time_series, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_tensorboard_time_series( + self, + request: tensorboard_service.GetTensorboardTimeSeriesRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_time_series.TensorboardTimeSeries: + r"""Gets a TensorboardTimeSeries. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetTensorboardTimeSeriesRequest`): + The request object. Request message for + ``TensorboardService.GetTensorboardTimeSeries``. + name (:class:`str`): + Required. The name of the TensorboardTimeSeries + resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries: + TensorboardTimeSeries maps to times + series produced in training runs + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_tensorboard_time_series, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def update_tensorboard_time_series( + self, + request: tensorboard_service.UpdateTensorboardTimeSeriesRequest = None, + *, + tensorboard_time_series: gca_tensorboard_time_series.TensorboardTimeSeries = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_time_series.TensorboardTimeSeries: + r"""Updates a TensorboardTimeSeries. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateTensorboardTimeSeriesRequest`): + The request object. Request message for + ``TensorboardService.UpdateTensorboardTimeSeries``. + tensorboard_time_series (:class:`google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries`): + Required. The TensorboardTimeSeries' ``name`` field is + used to identify the TensorboardTimeSeries to be + updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardTimeSeries resource by the + update. The fields specified in the update_mask are + relative to the resource, not the full request. A field + will be overwritten if it is in the mask. If the user + does not provide a mask then all fields will be + overwritten if new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries: + TensorboardTimeSeries maps to times + series produced in training runs + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_time_series, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.UpdateTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_tensorboard_time_series, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "tensorboard_time_series.name", + request.tensorboard_time_series.name, + ), + ) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_tensorboard_time_series( + self, + request: tensorboard_service.ListTensorboardTimeSeriesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardTimeSeriesAsyncPager: + r"""Lists TensorboardTimeSeries in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesRequest`): + The request object. Request message for + ``TensorboardService.ListTensorboardTimeSeries``. + parent (:class:`str`): + Required. The resource name of the + TensorboardRun to list + TensorboardTimeSeries. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardTimeSeriesAsyncPager: + Response message for + ``TensorboardService.ListTensorboardTimeSeries``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_tensorboard_time_series, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTensorboardTimeSeriesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_tensorboard_time_series( + self, + request: tensorboard_service.DeleteTensorboardTimeSeriesRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a TensorboardTimeSeries. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteTensorboardTimeSeriesRequest`): + The request object. Request message for + ``TensorboardService.DeleteTensorboardTimeSeries``. + name (:class:`str`): + Required. The name of the TensorboardTimeSeries to be + deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.DeleteTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_tensorboard_time_series, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def read_tensorboard_time_series_data( + self, + request: tensorboard_service.ReadTensorboardTimeSeriesDataRequest = None, + *, + tensorboard_time_series: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_service.ReadTensorboardTimeSeriesDataResponse: + r"""Reads a TensorboardTimeSeries' data. Data is returned in + paginated responses. By default, if the number of data points + stored is less than 1000, all data will be returned. Otherwise, + 1000 data points will be randomly selected from this time series + and returned. This value can be changed by changing + max_data_points. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ReadTensorboardTimeSeriesDataRequest`): + The request object. Request message for + ``TensorboardService.ReadTensorboardTimeSeriesData``. + tensorboard_time_series (:class:`str`): + Required. The resource name of the TensorboardTimeSeries + to read data from. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ReadTensorboardTimeSeriesDataResponse: + Response message for + ``TensorboardService.ReadTensorboardTimeSeriesData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.read_tensorboard_time_series_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_time_series", request.tensorboard_time_series),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def read_tensorboard_blob_data( + self, + request: tensorboard_service.ReadTensorboardBlobDataRequest = None, + *, + time_series: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[AsyncIterable[tensorboard_service.ReadTensorboardBlobDataResponse]]: + r"""Gets bytes of TensorboardBlobs. + This is to allow reading blob data stored in consumer + project's Cloud Storage bucket without users having to + obtain Cloud Storage access permission. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ReadTensorboardBlobDataRequest`): + The request object. Request message for + ``TensorboardService.ReadTensorboardBlobData``. + time_series (:class:`str`): + Required. The resource name of the TensorboardTimeSeries + to list Blobs. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}' + + This corresponds to the ``time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + AsyncIterable[google.cloud.aiplatform_v1beta1.types.ReadTensorboardBlobDataResponse]: + Response message for + ``TensorboardService.ReadTensorboardBlobData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ReadTensorboardBlobDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if time_series is not None: + request.time_series = time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.read_tensorboard_blob_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("time_series", request.time_series),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def write_tensorboard_run_data( + self, + request: tensorboard_service.WriteTensorboardRunDataRequest = None, + *, + tensorboard_run: str = None, + time_series_data: Sequence[tensorboard_data.TimeSeriesData] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_service.WriteTensorboardRunDataResponse: + r"""Write time series data points into multiple + TensorboardTimeSeries under a TensorboardRun. If any + data fail to be ingested, an error will be returned. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.WriteTensorboardRunDataRequest`): + The request object. Request message for + ``TensorboardService.WriteTensorboardRunData``. + tensorboard_run (:class:`str`): + Required. The resource name of the TensorboardRun to + write data to. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``tensorboard_run`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + time_series_data (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.TimeSeriesData]`): + Required. The TensorboardTimeSeries + data to write. Values with in a time + series are indexed by their step value. + Repeated writes to the same step will + overwrite the existing value for that + step. + The upper limit of data points per write + request is 5000. + + This corresponds to the ``time_series_data`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.WriteTensorboardRunDataResponse: + Response message for + ``TensorboardService.WriteTensorboardRunData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_run, time_series_data]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.WriteTensorboardRunDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_run is not None: + request.tensorboard_run = tensorboard_run + + if time_series_data: + request.time_series_data.extend(time_series_data) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.write_tensorboard_run_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_run", request.tensorboard_run),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def export_tensorboard_time_series_data( + self, + request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest = None, + *, + tensorboard_time_series: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ExportTensorboardTimeSeriesDataAsyncPager: + r"""Exports a TensorboardTimeSeries' data. Data is + returned in paginated responses. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataRequest`): + The request object. Request message for + ``TensorboardService.ExportTensorboardTimeSeriesData``. + tensorboard_time_series (:class:`str`): + Required. The resource name of the TensorboardTimeSeries + to export data from. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ExportTensorboardTimeSeriesDataAsyncPager: + Response message for + ``TensorboardService.ExportTensorboardTimeSeriesData``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_tensorboard_time_series_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_time_series", request.tensorboard_time_series),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ExportTensorboardTimeSeriesDataAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("TensorboardServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py new file mode 100644 index 0000000000..39339e8e21 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py @@ -0,0 +1,2647 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Iterable, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import pagers +from google.cloud.aiplatform_v1beta1.types import encryption_spec +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard as gca_tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_data +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import TensorboardServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import TensorboardServiceGrpcTransport +from .transports.grpc_asyncio import TensorboardServiceGrpcAsyncIOTransport + + +class TensorboardServiceClientMeta(type): + """Metaclass for the TensorboardService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[TensorboardServiceTransport]] + _transport_registry["grpc"] = TensorboardServiceGrpcTransport + _transport_registry["grpc_asyncio"] = TensorboardServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[TensorboardServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class TensorboardServiceClient(metaclass=TensorboardServiceClientMeta): + """TensorboardService""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TensorboardServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TensorboardServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TensorboardServiceTransport: + """Return the transport used by the client instance. + + Returns: + TensorboardServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def tensorboard_path(project: str, location: str, tensorboard: str,) -> str: + """Return a fully-qualified tensorboard string.""" + return "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( + project=project, location=location, tensorboard=tensorboard, + ) + + @staticmethod + def parse_tensorboard_path(path: str) -> Dict[str, str]: + """Parse a tensorboard path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/tensorboards/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def tensorboard_experiment_path( + project: str, location: str, tensorboard: str, experiment: str, + ) -> str: + """Return a fully-qualified tensorboard_experiment string.""" + return "projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}".format( + project=project, + location=location, + tensorboard=tensorboard, + experiment=experiment, + ) + + @staticmethod + def parse_tensorboard_experiment_path(path: str) -> Dict[str, str]: + """Parse a tensorboard_experiment path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/tensorboards/(?P.+?)/experiments/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def tensorboard_run_path( + project: str, location: str, tensorboard: str, experiment: str, run: str, + ) -> str: + """Return a fully-qualified tensorboard_run string.""" + return "projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}".format( + project=project, + location=location, + tensorboard=tensorboard, + experiment=experiment, + run=run, + ) + + @staticmethod + def parse_tensorboard_run_path(path: str) -> Dict[str, str]: + """Parse a tensorboard_run path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/tensorboards/(?P.+?)/experiments/(?P.+?)/runs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def tensorboard_time_series_path( + project: str, + location: str, + tensorboard: str, + experiment: str, + run: str, + time_series: str, + ) -> str: + """Return a fully-qualified tensorboard_time_series string.""" + return "projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}".format( + project=project, + location=location, + tensorboard=tensorboard, + experiment=experiment, + run=run, + time_series=time_series, + ) + + @staticmethod + def parse_tensorboard_time_series_path(path: str) -> Dict[str, str]: + """Parse a tensorboard_time_series path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/tensorboards/(?P.+?)/experiments/(?P.+?)/runs/(?P.+?)/timeSeries/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, TensorboardServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the tensorboard service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, TensorboardServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, TensorboardServiceTransport): + # transport is a TensorboardServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_tensorboard( + self, + request: tensorboard_service.CreateTensorboardRequest = None, + *, + parent: str = None, + tensorboard: gca_tensorboard.Tensorboard = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Creates a Tensorboard. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateTensorboardRequest): + The request object. Request message for + ``TensorboardService.CreateTensorboard``. + parent (str): + Required. The resource name of the Location to create + the Tensorboard in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard (google.cloud.aiplatform_v1beta1.types.Tensorboard): + Required. The Tensorboard to create. + This corresponds to the ``tensorboard`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Tensorboard` Tensorboard is a physical database that stores users’ training metrics. + A default Tensorboard is provided in each region of a + GCP project. If needed users can also create extra + Tensorboards in their projects. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, tensorboard]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.CreateTensorboardRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.CreateTensorboardRequest): + request = tensorboard_service.CreateTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard is not None: + request.tensorboard = tensorboard + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_tensorboard] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_tensorboard.Tensorboard, + metadata_type=tensorboard_service.CreateTensorboardOperationMetadata, + ) + + # Done; return the response. + return response + + def get_tensorboard( + self, + request: tensorboard_service.GetTensorboardRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard.Tensorboard: + r"""Gets a Tensorboard. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetTensorboardRequest): + The request object. Request message for + ``TensorboardService.GetTensorboard``. + name (str): + Required. The name of the Tensorboard resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Tensorboard: + Tensorboard is a physical database + that stores users’ training metrics. A + default Tensorboard is provided in each + region of a GCP project. If needed users + can also create extra Tensorboards in + their projects. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.GetTensorboardRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.GetTensorboardRequest): + request = tensorboard_service.GetTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_tensorboard] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def update_tensorboard( + self, + request: tensorboard_service.UpdateTensorboardRequest = None, + *, + tensorboard: gca_tensorboard.Tensorboard = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Updates a Tensorboard. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateTensorboardRequest): + The request object. Request message for + ``TensorboardService.UpdateTensorboard``. + tensorboard (google.cloud.aiplatform_v1beta1.types.Tensorboard): + Required. The Tensorboard's ``name`` field is used to + identify the Tensorboard to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``tensorboard`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the Tensorboard resource by the update. + The fields specified in the update_mask are relative to + the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then all fields will be overwritten if + new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Tensorboard` Tensorboard is a physical database that stores users’ training metrics. + A default Tensorboard is provided in each region of a + GCP project. If needed users can also create extra + Tensorboards in their projects. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.UpdateTensorboardRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.UpdateTensorboardRequest): + request = tensorboard_service.UpdateTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard is not None: + request.tensorboard = tensorboard + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_tensorboard] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard.name", request.tensorboard.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + gca_tensorboard.Tensorboard, + metadata_type=tensorboard_service.UpdateTensorboardOperationMetadata, + ) + + # Done; return the response. + return response + + def list_tensorboards( + self, + request: tensorboard_service.ListTensorboardsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardsPager: + r"""Lists Tensorboards in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardsRequest): + The request object. Request message for + ``TensorboardService.ListTensorboards``. + parent (str): + Required. The resource name of the + Location to list Tensorboards. Format: + 'projects/{project}/locations/{location}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardsPager: + Response message for + ``TensorboardService.ListTensorboards``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ListTensorboardsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.ListTensorboardsRequest): + request = tensorboard_service.ListTensorboardsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_tensorboards] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTensorboardsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_tensorboard( + self, + request: tensorboard_service.DeleteTensorboardRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a Tensorboard. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteTensorboardRequest): + The request object. Request message for + ``TensorboardService.DeleteTensorboard``. + name (str): + Required. The name of the Tensorboard to be deleted. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.DeleteTensorboardRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.DeleteTensorboardRequest): + request = tensorboard_service.DeleteTensorboardRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_tensorboard] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def create_tensorboard_experiment( + self, + request: tensorboard_service.CreateTensorboardExperimentRequest = None, + *, + parent: str = None, + tensorboard_experiment: gca_tensorboard_experiment.TensorboardExperiment = None, + tensorboard_experiment_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_experiment.TensorboardExperiment: + r"""Creates a TensorboardExperiment. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateTensorboardExperimentRequest): + The request object. Request message for + ``TensorboardService.CreateTensorboardExperiment``. + parent (str): + Required. The resource name of the Tensorboard to create + the TensorboardExperiment in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_experiment (google.cloud.aiplatform_v1beta1.types.TensorboardExperiment): + The TensorboardExperiment to create. + This corresponds to the ``tensorboard_experiment`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_experiment_id (str): + Required. The ID to use for the Tensorboard experiment, + which will become the final component of the Tensorboard + experiment's resource name. + + This value should be 1-128 characters, and valid + characters are /[a-z][0-9]-/. + + This corresponds to the ``tensorboard_experiment_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardExperiment: + A TensorboardExperiment is a group of + TensorboardRuns, that are typically the + results of a training job run, in a + Tensorboard. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any( + [parent, tensorboard_experiment, tensorboard_experiment_id] + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.CreateTensorboardExperimentRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.CreateTensorboardExperimentRequest + ): + request = tensorboard_service.CreateTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard_experiment is not None: + request.tensorboard_experiment = tensorboard_experiment + if tensorboard_experiment_id is not None: + request.tensorboard_experiment_id = tensorboard_experiment_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.create_tensorboard_experiment + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_tensorboard_experiment( + self, + request: tensorboard_service.GetTensorboardExperimentRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_experiment.TensorboardExperiment: + r"""Gets a TensorboardExperiment. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetTensorboardExperimentRequest): + The request object. Request message for + ``TensorboardService.GetTensorboardExperiment``. + name (str): + Required. The name of the TensorboardExperiment + resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardExperiment: + A TensorboardExperiment is a group of + TensorboardRuns, that are typically the + results of a training job run, in a + Tensorboard. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.GetTensorboardExperimentRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.GetTensorboardExperimentRequest): + request = tensorboard_service.GetTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.get_tensorboard_experiment + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def update_tensorboard_experiment( + self, + request: tensorboard_service.UpdateTensorboardExperimentRequest = None, + *, + tensorboard_experiment: gca_tensorboard_experiment.TensorboardExperiment = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_experiment.TensorboardExperiment: + r"""Updates a TensorboardExperiment. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateTensorboardExperimentRequest): + The request object. Request message for + ``TensorboardService.UpdateTensorboardExperiment``. + tensorboard_experiment (google.cloud.aiplatform_v1beta1.types.TensorboardExperiment): + Required. The TensorboardExperiment's ``name`` field is + used to identify the TensorboardExperiment to be + updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``tensorboard_experiment`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardExperiment resource by the + update. The fields specified in the update_mask are + relative to the resource, not the full request. A field + will be overwritten if it is in the mask. If the user + does not provide a mask then all fields will be + overwritten if new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardExperiment: + A TensorboardExperiment is a group of + TensorboardRuns, that are typically the + results of a training job run, in a + Tensorboard. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_experiment, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.UpdateTensorboardExperimentRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.UpdateTensorboardExperimentRequest + ): + request = tensorboard_service.UpdateTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_experiment is not None: + request.tensorboard_experiment = tensorboard_experiment + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.update_tensorboard_experiment + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_experiment.name", request.tensorboard_experiment.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_tensorboard_experiments( + self, + request: tensorboard_service.ListTensorboardExperimentsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardExperimentsPager: + r"""Lists TensorboardExperiments in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsRequest): + The request object. Request message for + ``TensorboardService.ListTensorboardExperiments``. + parent (str): + Required. The resource name of the + Tensorboard to list + TensorboardExperiments. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardExperimentsPager: + Response message for + ``TensorboardService.ListTensorboardExperiments``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ListTensorboardExperimentsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.ListTensorboardExperimentsRequest + ): + request = tensorboard_service.ListTensorboardExperimentsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.list_tensorboard_experiments + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTensorboardExperimentsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_tensorboard_experiment( + self, + request: tensorboard_service.DeleteTensorboardExperimentRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a TensorboardExperiment. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteTensorboardExperimentRequest): + The request object. Request message for + ``TensorboardService.DeleteTensorboardExperiment``. + name (str): + Required. The name of the TensorboardExperiment to be + deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.DeleteTensorboardExperimentRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.DeleteTensorboardExperimentRequest + ): + request = tensorboard_service.DeleteTensorboardExperimentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.delete_tensorboard_experiment + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def create_tensorboard_run( + self, + request: tensorboard_service.CreateTensorboardRunRequest = None, + *, + parent: str = None, + tensorboard_run: gca_tensorboard_run.TensorboardRun = None, + tensorboard_run_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_run.TensorboardRun: + r"""Creates a TensorboardRun. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateTensorboardRunRequest): + The request object. Request message for + ``TensorboardService.CreateTensorboardRun``. + parent (str): + Required. The resource name of the Tensorboard to create + the TensorboardRun in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_run (google.cloud.aiplatform_v1beta1.types.TensorboardRun): + Required. The TensorboardRun to + create. + + This corresponds to the ``tensorboard_run`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_run_id (str): + Required. The ID to use for the Tensorboard run, which + will become the final component of the Tensorboard run's + resource name. + + This value should be 1-128 characters, and valid + characters are /[a-z][0-9]-/. + + This corresponds to the ``tensorboard_run_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardRun: + TensorboardRun maps to a specific + execution of a training job with a given + set of hyperparameter values, model + definition, dataset, etc + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.CreateTensorboardRunRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.CreateTensorboardRunRequest): + request = tensorboard_service.CreateTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard_run is not None: + request.tensorboard_run = tensorboard_run + if tensorboard_run_id is not None: + request.tensorboard_run_id = tensorboard_run_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_tensorboard_run] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_tensorboard_run( + self, + request: tensorboard_service.GetTensorboardRunRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_run.TensorboardRun: + r"""Gets a TensorboardRun. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetTensorboardRunRequest): + The request object. Request message for + ``TensorboardService.GetTensorboardRun``. + name (str): + Required. The name of the TensorboardRun resource. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardRun: + TensorboardRun maps to a specific + execution of a training job with a given + set of hyperparameter values, model + definition, dataset, etc + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.GetTensorboardRunRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.GetTensorboardRunRequest): + request = tensorboard_service.GetTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_tensorboard_run] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def update_tensorboard_run( + self, + request: tensorboard_service.UpdateTensorboardRunRequest = None, + *, + tensorboard_run: gca_tensorboard_run.TensorboardRun = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_run.TensorboardRun: + r"""Updates a TensorboardRun. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateTensorboardRunRequest): + The request object. Request message for + ``TensorboardService.UpdateTensorboardRun``. + tensorboard_run (google.cloud.aiplatform_v1beta1.types.TensorboardRun): + Required. The TensorboardRun's ``name`` field is used to + identify the TensorboardRun to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``tensorboard_run`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardRun resource by the + update. The fields specified in the update_mask are + relative to the resource, not the full request. A field + will be overwritten if it is in the mask. If the user + does not provide a mask then all fields will be + overwritten if new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardRun: + TensorboardRun maps to a specific + execution of a training job with a given + set of hyperparameter values, model + definition, dataset, etc + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_run, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.UpdateTensorboardRunRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.UpdateTensorboardRunRequest): + request = tensorboard_service.UpdateTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_run is not None: + request.tensorboard_run = tensorboard_run + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_tensorboard_run] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_run.name", request.tensorboard_run.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_tensorboard_runs( + self, + request: tensorboard_service.ListTensorboardRunsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardRunsPager: + r"""Lists TensorboardRuns in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsRequest): + The request object. Request message for + ``TensorboardService.ListTensorboardRuns``. + parent (str): + Required. The resource name of the + Tensorboard to list TensorboardRuns. + Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardRunsPager: + Response message for + ``TensorboardService.ListTensorboardRuns``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ListTensorboardRunsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.ListTensorboardRunsRequest): + request = tensorboard_service.ListTensorboardRunsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_tensorboard_runs] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTensorboardRunsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_tensorboard_run( + self, + request: tensorboard_service.DeleteTensorboardRunRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a TensorboardRun. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteTensorboardRunRequest): + The request object. Request message for + ``TensorboardService.DeleteTensorboardRun``. + name (str): + Required. The name of the TensorboardRun to be deleted. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.DeleteTensorboardRunRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.DeleteTensorboardRunRequest): + request = tensorboard_service.DeleteTensorboardRunRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_tensorboard_run] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def create_tensorboard_time_series( + self, + request: tensorboard_service.CreateTensorboardTimeSeriesRequest = None, + *, + parent: str = None, + tensorboard_time_series: gca_tensorboard_time_series.TensorboardTimeSeries = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_time_series.TensorboardTimeSeries: + r"""Creates a TensorboardTimeSeries. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateTensorboardTimeSeriesRequest): + The request object. Request message for + ``TensorboardService.CreateTensorboardTimeSeries``. + parent (str): + Required. The resource name of the TensorboardRun to + create the TensorboardTimeSeries in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tensorboard_time_series (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries): + Required. The TensorboardTimeSeries + to create. + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries: + TensorboardTimeSeries maps to times + series produced in training runs + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, tensorboard_time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.CreateTensorboardTimeSeriesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.CreateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.CreateTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.create_tensorboard_time_series + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_tensorboard_time_series( + self, + request: tensorboard_service.GetTensorboardTimeSeriesRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_time_series.TensorboardTimeSeries: + r"""Gets a TensorboardTimeSeries. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetTensorboardTimeSeriesRequest): + The request object. Request message for + ``TensorboardService.GetTensorboardTimeSeries``. + name (str): + Required. The name of the TensorboardTimeSeries + resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries: + TensorboardTimeSeries maps to times + series produced in training runs + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.GetTensorboardTimeSeriesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.GetTensorboardTimeSeriesRequest): + request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.get_tensorboard_time_series + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def update_tensorboard_time_series( + self, + request: tensorboard_service.UpdateTensorboardTimeSeriesRequest = None, + *, + tensorboard_time_series: gca_tensorboard_time_series.TensorboardTimeSeries = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_tensorboard_time_series.TensorboardTimeSeries: + r"""Updates a TensorboardTimeSeries. + + Args: + request (google.cloud.aiplatform_v1beta1.types.UpdateTensorboardTimeSeriesRequest): + The request object. Request message for + ``TensorboardService.UpdateTensorboardTimeSeries``. + tensorboard_time_series (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries): + Required. The TensorboardTimeSeries' ``name`` field is + used to identify the TensorboardTimeSeries to be + updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardTimeSeries resource by the + update. The fields specified in the update_mask are + relative to the resource, not the full request. A field + will be overwritten if it is in the mask. If the user + does not provide a mask then all fields will be + overwritten if new values are specified. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries: + TensorboardTimeSeries maps to times + series produced in training runs + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_time_series, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.UpdateTensorboardTimeSeriesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.UpdateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.UpdateTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.update_tensorboard_time_series + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + ( + ( + "tensorboard_time_series.name", + request.tensorboard_time_series.name, + ), + ) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_tensorboard_time_series( + self, + request: tensorboard_service.ListTensorboardTimeSeriesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTensorboardTimeSeriesPager: + r"""Lists TensorboardTimeSeries in a Location. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesRequest): + The request object. Request message for + ``TensorboardService.ListTensorboardTimeSeries``. + parent (str): + Required. The resource name of the + TensorboardRun to list + TensorboardTimeSeries. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}' + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ListTensorboardTimeSeriesPager: + Response message for + ``TensorboardService.ListTensorboardTimeSeries``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ListTensorboardTimeSeriesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.ListTensorboardTimeSeriesRequest + ): + request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.list_tensorboard_time_series + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTensorboardTimeSeriesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_tensorboard_time_series( + self, + request: tensorboard_service.DeleteTensorboardTimeSeriesRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a TensorboardTimeSeries. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteTensorboardTimeSeriesRequest): + The request object. Request message for + ``TensorboardService.DeleteTensorboardTimeSeries``. + name (str): + Required. The name of the TensorboardTimeSeries to be + deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.DeleteTensorboardTimeSeriesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.DeleteTensorboardTimeSeriesRequest + ): + request = tensorboard_service.DeleteTensorboardTimeSeriesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.delete_tensorboard_time_series + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def read_tensorboard_time_series_data( + self, + request: tensorboard_service.ReadTensorboardTimeSeriesDataRequest = None, + *, + tensorboard_time_series: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_service.ReadTensorboardTimeSeriesDataResponse: + r"""Reads a TensorboardTimeSeries' data. Data is returned in + paginated responses. By default, if the number of data points + stored is less than 1000, all data will be returned. Otherwise, + 1000 data points will be randomly selected from this time series + and returned. This value can be changed by changing + max_data_points. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ReadTensorboardTimeSeriesDataRequest): + The request object. Request message for + ``TensorboardService.ReadTensorboardTimeSeriesData``. + tensorboard_time_series (str): + Required. The resource name of the TensorboardTimeSeries + to read data from. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ReadTensorboardTimeSeriesDataResponse: + Response message for + ``TensorboardService.ReadTensorboardTimeSeriesData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ReadTensorboardTimeSeriesDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.ReadTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.read_tensorboard_time_series_data + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_time_series", request.tensorboard_time_series),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def read_tensorboard_blob_data( + self, + request: tensorboard_service.ReadTensorboardBlobDataRequest = None, + *, + time_series: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[tensorboard_service.ReadTensorboardBlobDataResponse]: + r"""Gets bytes of TensorboardBlobs. + This is to allow reading blob data stored in consumer + project's Cloud Storage bucket without users having to + obtain Cloud Storage access permission. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ReadTensorboardBlobDataRequest): + The request object. Request message for + ``TensorboardService.ReadTensorboardBlobData``. + time_series (str): + Required. The resource name of the TensorboardTimeSeries + to list Blobs. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}' + + This corresponds to the ``time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + Iterable[google.cloud.aiplatform_v1beta1.types.ReadTensorboardBlobDataResponse]: + Response message for + ``TensorboardService.ReadTensorboardBlobData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ReadTensorboardBlobDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.ReadTensorboardBlobDataRequest): + request = tensorboard_service.ReadTensorboardBlobDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if time_series is not None: + request.time_series = time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.read_tensorboard_blob_data + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("time_series", request.time_series),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def write_tensorboard_run_data( + self, + request: tensorboard_service.WriteTensorboardRunDataRequest = None, + *, + tensorboard_run: str = None, + time_series_data: Sequence[tensorboard_data.TimeSeriesData] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tensorboard_service.WriteTensorboardRunDataResponse: + r"""Write time series data points into multiple + TensorboardTimeSeries under a TensorboardRun. If any + data fail to be ingested, an error will be returned. + + Args: + request (google.cloud.aiplatform_v1beta1.types.WriteTensorboardRunDataRequest): + The request object. Request message for + ``TensorboardService.WriteTensorboardRunData``. + tensorboard_run (str): + Required. The resource name of the TensorboardRun to + write data to. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + + This corresponds to the ``tensorboard_run`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + time_series_data (Sequence[google.cloud.aiplatform_v1beta1.types.TimeSeriesData]): + Required. The TensorboardTimeSeries + data to write. Values with in a time + series are indexed by their step value. + Repeated writes to the same step will + overwrite the existing value for that + step. + The upper limit of data points per write + request is 5000. + + This corresponds to the ``time_series_data`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.WriteTensorboardRunDataResponse: + Response message for + ``TensorboardService.WriteTensorboardRunData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_run, time_series_data]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.WriteTensorboardRunDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, tensorboard_service.WriteTensorboardRunDataRequest): + request = tensorboard_service.WriteTensorboardRunDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_run is not None: + request.tensorboard_run = tensorboard_run + if time_series_data is not None: + request.time_series_data = time_series_data + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.write_tensorboard_run_data + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_run", request.tensorboard_run),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def export_tensorboard_time_series_data( + self, + request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest = None, + *, + tensorboard_time_series: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ExportTensorboardTimeSeriesDataPager: + r"""Exports a TensorboardTimeSeries' data. Data is + returned in paginated responses. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataRequest): + The request object. Request message for + ``TensorboardService.ExportTensorboardTimeSeriesData``. + tensorboard_time_series (str): + Required. The resource name of the TensorboardTimeSeries + to export data from. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + + This corresponds to the ``tensorboard_time_series`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.tensorboard_service.pagers.ExportTensorboardTimeSeriesDataPager: + Response message for + ``TensorboardService.ExportTensorboardTimeSeriesData``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tensorboard_time_series]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a tensorboard_service.ExportTensorboardTimeSeriesDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance( + request, tensorboard_service.ExportTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest( + request + ) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if tensorboard_time_series is not None: + request.tensorboard_time_series = tensorboard_time_series + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.export_tensorboard_time_series_data + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_time_series", request.tensorboard_time_series),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ExportTensorboardTimeSeriesDataPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("TensorboardServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py new file mode 100644 index 0000000000..acc2c40676 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py @@ -0,0 +1,700 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_data +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series + + +class ListTensorboardsPager: + """A pager for iterating through ``list_tensorboards`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``tensorboards`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTensorboards`` requests and continue to iterate + through the ``tensorboards`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., tensorboard_service.ListTensorboardsResponse], + request: tensorboard_service.ListTensorboardsRequest, + response: tensorboard_service.ListTensorboardsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[tensorboard_service.ListTensorboardsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[tensorboard.Tensorboard]: + for page in self.pages: + yield from page.tensorboards + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardsAsyncPager: + """A pager for iterating through ``list_tensorboards`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``tensorboards`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTensorboards`` requests and continue to iterate + through the ``tensorboards`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[tensorboard_service.ListTensorboardsResponse]], + request: tensorboard_service.ListTensorboardsRequest, + response: tensorboard_service.ListTensorboardsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[tensorboard_service.ListTensorboardsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[tensorboard.Tensorboard]: + async def async_generator(): + async for page in self.pages: + for response in page.tensorboards: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardExperimentsPager: + """A pager for iterating through ``list_tensorboard_experiments`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``tensorboard_experiments`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTensorboardExperiments`` requests and continue to iterate + through the ``tensorboard_experiments`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., tensorboard_service.ListTensorboardExperimentsResponse], + request: tensorboard_service.ListTensorboardExperimentsRequest, + response: tensorboard_service.ListTensorboardExperimentsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardExperimentsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[tensorboard_service.ListTensorboardExperimentsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[tensorboard_experiment.TensorboardExperiment]: + for page in self.pages: + yield from page.tensorboard_experiments + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardExperimentsAsyncPager: + """A pager for iterating through ``list_tensorboard_experiments`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``tensorboard_experiments`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTensorboardExperiments`` requests and continue to iterate + through the ``tensorboard_experiments`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[tensorboard_service.ListTensorboardExperimentsResponse] + ], + request: tensorboard_service.ListTensorboardExperimentsRequest, + response: tensorboard_service.ListTensorboardExperimentsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardExperimentsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[tensorboard_service.ListTensorboardExperimentsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[tensorboard_experiment.TensorboardExperiment]: + async def async_generator(): + async for page in self.pages: + for response in page.tensorboard_experiments: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardRunsPager: + """A pager for iterating through ``list_tensorboard_runs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``tensorboard_runs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTensorboardRuns`` requests and continue to iterate + through the ``tensorboard_runs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., tensorboard_service.ListTensorboardRunsResponse], + request: tensorboard_service.ListTensorboardRunsRequest, + response: tensorboard_service.ListTensorboardRunsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardRunsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[tensorboard_service.ListTensorboardRunsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[tensorboard_run.TensorboardRun]: + for page in self.pages: + yield from page.tensorboard_runs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardRunsAsyncPager: + """A pager for iterating through ``list_tensorboard_runs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``tensorboard_runs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTensorboardRuns`` requests and continue to iterate + through the ``tensorboard_runs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[tensorboard_service.ListTensorboardRunsResponse] + ], + request: tensorboard_service.ListTensorboardRunsRequest, + response: tensorboard_service.ListTensorboardRunsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardRunsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[tensorboard_service.ListTensorboardRunsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[tensorboard_run.TensorboardRun]: + async def async_generator(): + async for page in self.pages: + for response in page.tensorboard_runs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardTimeSeriesPager: + """A pager for iterating through ``list_tensorboard_time_series`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``tensorboard_time_series`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTensorboardTimeSeries`` requests and continue to iterate + through the ``tensorboard_time_series`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., tensorboard_service.ListTensorboardTimeSeriesResponse], + request: tensorboard_service.ListTensorboardTimeSeriesRequest, + response: tensorboard_service.ListTensorboardTimeSeriesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[tensorboard_service.ListTensorboardTimeSeriesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[tensorboard_time_series.TensorboardTimeSeries]: + for page in self.pages: + yield from page.tensorboard_time_series + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTensorboardTimeSeriesAsyncPager: + """A pager for iterating through ``list_tensorboard_time_series`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``tensorboard_time_series`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTensorboardTimeSeries`` requests and continue to iterate + through the ``tensorboard_time_series`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[tensorboard_service.ListTensorboardTimeSeriesResponse] + ], + request: tensorboard_service.ListTensorboardTimeSeriesRequest, + response: tensorboard_service.ListTensorboardTimeSeriesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[tensorboard_service.ListTensorboardTimeSeriesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[tensorboard_time_series.TensorboardTimeSeries]: + async def async_generator(): + async for page in self.pages: + for response in page.tensorboard_time_series: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ExportTensorboardTimeSeriesDataPager: + """A pager for iterating through ``export_tensorboard_time_series_data`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse` object, and + provides an ``__iter__`` method to iterate through its + ``time_series_data_points`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ExportTensorboardTimeSeriesData`` requests and continue to iterate + through the ``time_series_data_points`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., tensorboard_service.ExportTensorboardTimeSeriesDataResponse + ], + request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest, + response: tensorboard_service.ExportTensorboardTimeSeriesDataResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages( + self, + ) -> Iterable[tensorboard_service.ExportTensorboardTimeSeriesDataResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[tensorboard_data.TimeSeriesDataPoint]: + for page in self.pages: + yield from page.time_series_data_points + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ExportTensorboardTimeSeriesDataAsyncPager: + """A pager for iterating through ``export_tensorboard_time_series_data`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``time_series_data_points`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ExportTensorboardTimeSeriesData`` requests and continue to iterate + through the ``time_series_data_points`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[tensorboard_service.ExportTensorboardTimeSeriesDataResponse] + ], + request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest, + response: tensorboard_service.ExportTensorboardTimeSeriesDataResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[tensorboard_service.ExportTensorboardTimeSeriesDataResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[tensorboard_data.TimeSeriesDataPoint]: + async def async_generator(): + async for page in self.pages: + for response in page.time_series_data_points: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/__init__.py new file mode 100644 index 0000000000..86ffc7d6b2 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/__init__.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import TensorboardServiceTransport +from .grpc import TensorboardServiceGrpcTransport +from .grpc_asyncio import TensorboardServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[TensorboardServiceTransport]] +_transport_registry["grpc"] = TensorboardServiceGrpcTransport +_transport_registry["grpc_asyncio"] = TensorboardServiceGrpcAsyncIOTransport + +__all__ = ( + "TensorboardServiceTransport", + "TensorboardServiceGrpcTransport", + "TensorboardServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py new file mode 100644 index 0000000000..2e2dea1764 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py @@ -0,0 +1,509 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.longrunning import operations_pb2 as operations # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class TensorboardServiceTransport(abc.ABC): + """Abstract transport class for TensorboardService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=self._scopes, quota_project_id=quota_project_id + ) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_tensorboard: gapic_v1.method.wrap_method( + self.create_tensorboard, default_timeout=None, client_info=client_info, + ), + self.get_tensorboard: gapic_v1.method.wrap_method( + self.get_tensorboard, default_timeout=None, client_info=client_info, + ), + self.update_tensorboard: gapic_v1.method.wrap_method( + self.update_tensorboard, default_timeout=None, client_info=client_info, + ), + self.list_tensorboards: gapic_v1.method.wrap_method( + self.list_tensorboards, default_timeout=None, client_info=client_info, + ), + self.delete_tensorboard: gapic_v1.method.wrap_method( + self.delete_tensorboard, default_timeout=None, client_info=client_info, + ), + self.create_tensorboard_experiment: gapic_v1.method.wrap_method( + self.create_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_experiment: gapic_v1.method.wrap_method( + self.get_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard_experiment: gapic_v1.method.wrap_method( + self.update_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_experiments: gapic_v1.method.wrap_method( + self.list_tensorboard_experiments, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_experiment: gapic_v1.method.wrap_method( + self.delete_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_run: gapic_v1.method.wrap_method( + self.create_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_run: gapic_v1.method.wrap_method( + self.get_tensorboard_run, default_timeout=None, client_info=client_info, + ), + self.update_tensorboard_run: gapic_v1.method.wrap_method( + self.update_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_runs: gapic_v1.method.wrap_method( + self.list_tensorboard_runs, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_run: gapic_v1.method.wrap_method( + self.delete_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_time_series: gapic_v1.method.wrap_method( + self.create_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_time_series: gapic_v1.method.wrap_method( + self.get_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard_time_series: gapic_v1.method.wrap_method( + self.update_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_time_series: gapic_v1.method.wrap_method( + self.list_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_time_series: gapic_v1.method.wrap_method( + self.delete_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_time_series_data: gapic_v1.method.wrap_method( + self.read_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_blob_data: gapic_v1.method.wrap_method( + self.read_tensorboard_blob_data, + default_timeout=None, + client_info=client_info, + ), + self.write_tensorboard_run_data: gapic_v1.method.wrap_method( + self.write_tensorboard_run_data, + default_timeout=None, + client_info=client_info, + ), + self.export_tensorboard_time_series_data: gapic_v1.method.wrap_method( + self.export_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_tensorboard( + self, + ) -> typing.Callable[ + [tensorboard_service.CreateTensorboardRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def get_tensorboard( + self, + ) -> typing.Callable[ + [tensorboard_service.GetTensorboardRequest], + typing.Union[ + tensorboard.Tensorboard, typing.Awaitable[tensorboard.Tensorboard] + ], + ]: + raise NotImplementedError() + + @property + def update_tensorboard( + self, + ) -> typing.Callable[ + [tensorboard_service.UpdateTensorboardRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def list_tensorboards( + self, + ) -> typing.Callable[ + [tensorboard_service.ListTensorboardsRequest], + typing.Union[ + tensorboard_service.ListTensorboardsResponse, + typing.Awaitable[tensorboard_service.ListTensorboardsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_tensorboard( + self, + ) -> typing.Callable[ + [tensorboard_service.DeleteTensorboardRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def create_tensorboard_experiment( + self, + ) -> typing.Callable[ + [tensorboard_service.CreateTensorboardExperimentRequest], + typing.Union[ + gca_tensorboard_experiment.TensorboardExperiment, + typing.Awaitable[gca_tensorboard_experiment.TensorboardExperiment], + ], + ]: + raise NotImplementedError() + + @property + def get_tensorboard_experiment( + self, + ) -> typing.Callable[ + [tensorboard_service.GetTensorboardExperimentRequest], + typing.Union[ + tensorboard_experiment.TensorboardExperiment, + typing.Awaitable[tensorboard_experiment.TensorboardExperiment], + ], + ]: + raise NotImplementedError() + + @property + def update_tensorboard_experiment( + self, + ) -> typing.Callable[ + [tensorboard_service.UpdateTensorboardExperimentRequest], + typing.Union[ + gca_tensorboard_experiment.TensorboardExperiment, + typing.Awaitable[gca_tensorboard_experiment.TensorboardExperiment], + ], + ]: + raise NotImplementedError() + + @property + def list_tensorboard_experiments( + self, + ) -> typing.Callable[ + [tensorboard_service.ListTensorboardExperimentsRequest], + typing.Union[ + tensorboard_service.ListTensorboardExperimentsResponse, + typing.Awaitable[tensorboard_service.ListTensorboardExperimentsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_tensorboard_experiment( + self, + ) -> typing.Callable[ + [tensorboard_service.DeleteTensorboardExperimentRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def create_tensorboard_run( + self, + ) -> typing.Callable[ + [tensorboard_service.CreateTensorboardRunRequest], + typing.Union[ + gca_tensorboard_run.TensorboardRun, + typing.Awaitable[gca_tensorboard_run.TensorboardRun], + ], + ]: + raise NotImplementedError() + + @property + def get_tensorboard_run( + self, + ) -> typing.Callable[ + [tensorboard_service.GetTensorboardRunRequest], + typing.Union[ + tensorboard_run.TensorboardRun, + typing.Awaitable[tensorboard_run.TensorboardRun], + ], + ]: + raise NotImplementedError() + + @property + def update_tensorboard_run( + self, + ) -> typing.Callable[ + [tensorboard_service.UpdateTensorboardRunRequest], + typing.Union[ + gca_tensorboard_run.TensorboardRun, + typing.Awaitable[gca_tensorboard_run.TensorboardRun], + ], + ]: + raise NotImplementedError() + + @property + def list_tensorboard_runs( + self, + ) -> typing.Callable[ + [tensorboard_service.ListTensorboardRunsRequest], + typing.Union[ + tensorboard_service.ListTensorboardRunsResponse, + typing.Awaitable[tensorboard_service.ListTensorboardRunsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_tensorboard_run( + self, + ) -> typing.Callable[ + [tensorboard_service.DeleteTensorboardRunRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def create_tensorboard_time_series( + self, + ) -> typing.Callable[ + [tensorboard_service.CreateTensorboardTimeSeriesRequest], + typing.Union[ + gca_tensorboard_time_series.TensorboardTimeSeries, + typing.Awaitable[gca_tensorboard_time_series.TensorboardTimeSeries], + ], + ]: + raise NotImplementedError() + + @property + def get_tensorboard_time_series( + self, + ) -> typing.Callable[ + [tensorboard_service.GetTensorboardTimeSeriesRequest], + typing.Union[ + tensorboard_time_series.TensorboardTimeSeries, + typing.Awaitable[tensorboard_time_series.TensorboardTimeSeries], + ], + ]: + raise NotImplementedError() + + @property + def update_tensorboard_time_series( + self, + ) -> typing.Callable[ + [tensorboard_service.UpdateTensorboardTimeSeriesRequest], + typing.Union[ + gca_tensorboard_time_series.TensorboardTimeSeries, + typing.Awaitable[gca_tensorboard_time_series.TensorboardTimeSeries], + ], + ]: + raise NotImplementedError() + + @property + def list_tensorboard_time_series( + self, + ) -> typing.Callable[ + [tensorboard_service.ListTensorboardTimeSeriesRequest], + typing.Union[ + tensorboard_service.ListTensorboardTimeSeriesResponse, + typing.Awaitable[tensorboard_service.ListTensorboardTimeSeriesResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_tensorboard_time_series( + self, + ) -> typing.Callable[ + [tensorboard_service.DeleteTensorboardTimeSeriesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def read_tensorboard_time_series_data( + self, + ) -> typing.Callable[ + [tensorboard_service.ReadTensorboardTimeSeriesDataRequest], + typing.Union[ + tensorboard_service.ReadTensorboardTimeSeriesDataResponse, + typing.Awaitable[tensorboard_service.ReadTensorboardTimeSeriesDataResponse], + ], + ]: + raise NotImplementedError() + + @property + def read_tensorboard_blob_data( + self, + ) -> typing.Callable[ + [tensorboard_service.ReadTensorboardBlobDataRequest], + typing.Union[ + tensorboard_service.ReadTensorboardBlobDataResponse, + typing.Awaitable[tensorboard_service.ReadTensorboardBlobDataResponse], + ], + ]: + raise NotImplementedError() + + @property + def write_tensorboard_run_data( + self, + ) -> typing.Callable[ + [tensorboard_service.WriteTensorboardRunDataRequest], + typing.Union[ + tensorboard_service.WriteTensorboardRunDataResponse, + typing.Awaitable[tensorboard_service.WriteTensorboardRunDataResponse], + ], + ]: + raise NotImplementedError() + + @property + def export_tensorboard_time_series_data( + self, + ) -> typing.Callable[ + [tensorboard_service.ExportTensorboardTimeSeriesDataRequest], + typing.Union[ + tensorboard_service.ExportTensorboardTimeSeriesDataResponse, + typing.Awaitable[ + tensorboard_service.ExportTensorboardTimeSeriesDataResponse + ], + ], + ]: + raise NotImplementedError() + + +__all__ = ("TensorboardServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py new file mode 100644 index 0000000000..02f697b2ae --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py @@ -0,0 +1,962 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import TensorboardServiceTransport, DEFAULT_CLIENT_INFO + + +class TensorboardServiceGrpcTransport(TensorboardServiceTransport): + """gRPC backend transport for TensorboardService. + + TensorboardService + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + + # Return the client from cache. + return self._operations_client + + @property + def create_tensorboard( + self, + ) -> Callable[[tensorboard_service.CreateTensorboardRequest], operations.Operation]: + r"""Return a callable for the create tensorboard method over gRPC. + + Creates a Tensorboard. + + Returns: + Callable[[~.CreateTensorboardRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard" not in self._stubs: + self._stubs["create_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboard", + request_serializer=tensorboard_service.CreateTensorboardRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_tensorboard"] + + @property + def get_tensorboard( + self, + ) -> Callable[[tensorboard_service.GetTensorboardRequest], tensorboard.Tensorboard]: + r"""Return a callable for the get tensorboard method over gRPC. + + Gets a Tensorboard. + + Returns: + Callable[[~.GetTensorboardRequest], + ~.Tensorboard]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard" not in self._stubs: + self._stubs["get_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboard", + request_serializer=tensorboard_service.GetTensorboardRequest.serialize, + response_deserializer=tensorboard.Tensorboard.deserialize, + ) + return self._stubs["get_tensorboard"] + + @property + def update_tensorboard( + self, + ) -> Callable[[tensorboard_service.UpdateTensorboardRequest], operations.Operation]: + r"""Return a callable for the update tensorboard method over gRPC. + + Updates a Tensorboard. + + Returns: + Callable[[~.UpdateTensorboardRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard" not in self._stubs: + self._stubs["update_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboard", + request_serializer=tensorboard_service.UpdateTensorboardRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["update_tensorboard"] + + @property + def list_tensorboards( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardsRequest], + tensorboard_service.ListTensorboardsResponse, + ]: + r"""Return a callable for the list tensorboards method over gRPC. + + Lists Tensorboards in a Location. + + Returns: + Callable[[~.ListTensorboardsRequest], + ~.ListTensorboardsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboards" not in self._stubs: + self._stubs["list_tensorboards"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboards", + request_serializer=tensorboard_service.ListTensorboardsRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardsResponse.deserialize, + ) + return self._stubs["list_tensorboards"] + + @property + def delete_tensorboard( + self, + ) -> Callable[[tensorboard_service.DeleteTensorboardRequest], operations.Operation]: + r"""Return a callable for the delete tensorboard method over gRPC. + + Deletes a Tensorboard. + + Returns: + Callable[[~.DeleteTensorboardRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard" not in self._stubs: + self._stubs["delete_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboard", + request_serializer=tensorboard_service.DeleteTensorboardRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard"] + + @property + def create_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardExperimentRequest], + gca_tensorboard_experiment.TensorboardExperiment, + ]: + r"""Return a callable for the create tensorboard experiment method over gRPC. + + Creates a TensorboardExperiment. + + Returns: + Callable[[~.CreateTensorboardExperimentRequest], + ~.TensorboardExperiment]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard_experiment" not in self._stubs: + self._stubs[ + "create_tensorboard_experiment" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboardExperiment", + request_serializer=tensorboard_service.CreateTensorboardExperimentRequest.serialize, + response_deserializer=gca_tensorboard_experiment.TensorboardExperiment.deserialize, + ) + return self._stubs["create_tensorboard_experiment"] + + @property + def get_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardExperimentRequest], + tensorboard_experiment.TensorboardExperiment, + ]: + r"""Return a callable for the get tensorboard experiment method over gRPC. + + Gets a TensorboardExperiment. + + Returns: + Callable[[~.GetTensorboardExperimentRequest], + ~.TensorboardExperiment]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard_experiment" not in self._stubs: + self._stubs["get_tensorboard_experiment"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboardExperiment", + request_serializer=tensorboard_service.GetTensorboardExperimentRequest.serialize, + response_deserializer=tensorboard_experiment.TensorboardExperiment.deserialize, + ) + return self._stubs["get_tensorboard_experiment"] + + @property + def update_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardExperimentRequest], + gca_tensorboard_experiment.TensorboardExperiment, + ]: + r"""Return a callable for the update tensorboard experiment method over gRPC. + + Updates a TensorboardExperiment. + + Returns: + Callable[[~.UpdateTensorboardExperimentRequest], + ~.TensorboardExperiment]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard_experiment" not in self._stubs: + self._stubs[ + "update_tensorboard_experiment" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboardExperiment", + request_serializer=tensorboard_service.UpdateTensorboardExperimentRequest.serialize, + response_deserializer=gca_tensorboard_experiment.TensorboardExperiment.deserialize, + ) + return self._stubs["update_tensorboard_experiment"] + + @property + def list_tensorboard_experiments( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardExperimentsRequest], + tensorboard_service.ListTensorboardExperimentsResponse, + ]: + r"""Return a callable for the list tensorboard experiments method over gRPC. + + Lists TensorboardExperiments in a Location. + + Returns: + Callable[[~.ListTensorboardExperimentsRequest], + ~.ListTensorboardExperimentsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboard_experiments" not in self._stubs: + self._stubs["list_tensorboard_experiments"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboardExperiments", + request_serializer=tensorboard_service.ListTensorboardExperimentsRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardExperimentsResponse.deserialize, + ) + return self._stubs["list_tensorboard_experiments"] + + @property + def delete_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardExperimentRequest], operations.Operation + ]: + r"""Return a callable for the delete tensorboard experiment method over gRPC. + + Deletes a TensorboardExperiment. + + Returns: + Callable[[~.DeleteTensorboardExperimentRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard_experiment" not in self._stubs: + self._stubs[ + "delete_tensorboard_experiment" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboardExperiment", + request_serializer=tensorboard_service.DeleteTensorboardExperimentRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard_experiment"] + + @property + def create_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardRunRequest], + gca_tensorboard_run.TensorboardRun, + ]: + r"""Return a callable for the create tensorboard run method over gRPC. + + Creates a TensorboardRun. + + Returns: + Callable[[~.CreateTensorboardRunRequest], + ~.TensorboardRun]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard_run" not in self._stubs: + self._stubs["create_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboardRun", + request_serializer=tensorboard_service.CreateTensorboardRunRequest.serialize, + response_deserializer=gca_tensorboard_run.TensorboardRun.deserialize, + ) + return self._stubs["create_tensorboard_run"] + + @property + def get_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardRunRequest], tensorboard_run.TensorboardRun + ]: + r"""Return a callable for the get tensorboard run method over gRPC. + + Gets a TensorboardRun. + + Returns: + Callable[[~.GetTensorboardRunRequest], + ~.TensorboardRun]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard_run" not in self._stubs: + self._stubs["get_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboardRun", + request_serializer=tensorboard_service.GetTensorboardRunRequest.serialize, + response_deserializer=tensorboard_run.TensorboardRun.deserialize, + ) + return self._stubs["get_tensorboard_run"] + + @property + def update_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardRunRequest], + gca_tensorboard_run.TensorboardRun, + ]: + r"""Return a callable for the update tensorboard run method over gRPC. + + Updates a TensorboardRun. + + Returns: + Callable[[~.UpdateTensorboardRunRequest], + ~.TensorboardRun]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard_run" not in self._stubs: + self._stubs["update_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboardRun", + request_serializer=tensorboard_service.UpdateTensorboardRunRequest.serialize, + response_deserializer=gca_tensorboard_run.TensorboardRun.deserialize, + ) + return self._stubs["update_tensorboard_run"] + + @property + def list_tensorboard_runs( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardRunsRequest], + tensorboard_service.ListTensorboardRunsResponse, + ]: + r"""Return a callable for the list tensorboard runs method over gRPC. + + Lists TensorboardRuns in a Location. + + Returns: + Callable[[~.ListTensorboardRunsRequest], + ~.ListTensorboardRunsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboard_runs" not in self._stubs: + self._stubs["list_tensorboard_runs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboardRuns", + request_serializer=tensorboard_service.ListTensorboardRunsRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardRunsResponse.deserialize, + ) + return self._stubs["list_tensorboard_runs"] + + @property + def delete_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardRunRequest], operations.Operation + ]: + r"""Return a callable for the delete tensorboard run method over gRPC. + + Deletes a TensorboardRun. + + Returns: + Callable[[~.DeleteTensorboardRunRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard_run" not in self._stubs: + self._stubs["delete_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboardRun", + request_serializer=tensorboard_service.DeleteTensorboardRunRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard_run"] + + @property + def create_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardTimeSeriesRequest], + gca_tensorboard_time_series.TensorboardTimeSeries, + ]: + r"""Return a callable for the create tensorboard time series method over gRPC. + + Creates a TensorboardTimeSeries. + + Returns: + Callable[[~.CreateTensorboardTimeSeriesRequest], + ~.TensorboardTimeSeries]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard_time_series" not in self._stubs: + self._stubs[ + "create_tensorboard_time_series" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboardTimeSeries", + request_serializer=tensorboard_service.CreateTensorboardTimeSeriesRequest.serialize, + response_deserializer=gca_tensorboard_time_series.TensorboardTimeSeries.deserialize, + ) + return self._stubs["create_tensorboard_time_series"] + + @property + def get_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardTimeSeriesRequest], + tensorboard_time_series.TensorboardTimeSeries, + ]: + r"""Return a callable for the get tensorboard time series method over gRPC. + + Gets a TensorboardTimeSeries. + + Returns: + Callable[[~.GetTensorboardTimeSeriesRequest], + ~.TensorboardTimeSeries]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard_time_series" not in self._stubs: + self._stubs["get_tensorboard_time_series"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboardTimeSeries", + request_serializer=tensorboard_service.GetTensorboardTimeSeriesRequest.serialize, + response_deserializer=tensorboard_time_series.TensorboardTimeSeries.deserialize, + ) + return self._stubs["get_tensorboard_time_series"] + + @property + def update_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardTimeSeriesRequest], + gca_tensorboard_time_series.TensorboardTimeSeries, + ]: + r"""Return a callable for the update tensorboard time series method over gRPC. + + Updates a TensorboardTimeSeries. + + Returns: + Callable[[~.UpdateTensorboardTimeSeriesRequest], + ~.TensorboardTimeSeries]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard_time_series" not in self._stubs: + self._stubs[ + "update_tensorboard_time_series" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboardTimeSeries", + request_serializer=tensorboard_service.UpdateTensorboardTimeSeriesRequest.serialize, + response_deserializer=gca_tensorboard_time_series.TensorboardTimeSeries.deserialize, + ) + return self._stubs["update_tensorboard_time_series"] + + @property + def list_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardTimeSeriesRequest], + tensorboard_service.ListTensorboardTimeSeriesResponse, + ]: + r"""Return a callable for the list tensorboard time series method over gRPC. + + Lists TensorboardTimeSeries in a Location. + + Returns: + Callable[[~.ListTensorboardTimeSeriesRequest], + ~.ListTensorboardTimeSeriesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboard_time_series" not in self._stubs: + self._stubs["list_tensorboard_time_series"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboardTimeSeries", + request_serializer=tensorboard_service.ListTensorboardTimeSeriesRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardTimeSeriesResponse.deserialize, + ) + return self._stubs["list_tensorboard_time_series"] + + @property + def delete_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardTimeSeriesRequest], operations.Operation + ]: + r"""Return a callable for the delete tensorboard time series method over gRPC. + + Deletes a TensorboardTimeSeries. + + Returns: + Callable[[~.DeleteTensorboardTimeSeriesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard_time_series" not in self._stubs: + self._stubs[ + "delete_tensorboard_time_series" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboardTimeSeries", + request_serializer=tensorboard_service.DeleteTensorboardTimeSeriesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard_time_series"] + + @property + def read_tensorboard_time_series_data( + self, + ) -> Callable[ + [tensorboard_service.ReadTensorboardTimeSeriesDataRequest], + tensorboard_service.ReadTensorboardTimeSeriesDataResponse, + ]: + r"""Return a callable for the read tensorboard time series + data method over gRPC. + + Reads a TensorboardTimeSeries' data. Data is returned in + paginated responses. By default, if the number of data points + stored is less than 1000, all data will be returned. Otherwise, + 1000 data points will be randomly selected from this time series + and returned. This value can be changed by changing + max_data_points. + + Returns: + Callable[[~.ReadTensorboardTimeSeriesDataRequest], + ~.ReadTensorboardTimeSeriesDataResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "read_tensorboard_time_series_data" not in self._stubs: + self._stubs[ + "read_tensorboard_time_series_data" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ReadTensorboardTimeSeriesData", + request_serializer=tensorboard_service.ReadTensorboardTimeSeriesDataRequest.serialize, + response_deserializer=tensorboard_service.ReadTensorboardTimeSeriesDataResponse.deserialize, + ) + return self._stubs["read_tensorboard_time_series_data"] + + @property + def read_tensorboard_blob_data( + self, + ) -> Callable[ + [tensorboard_service.ReadTensorboardBlobDataRequest], + tensorboard_service.ReadTensorboardBlobDataResponse, + ]: + r"""Return a callable for the read tensorboard blob data method over gRPC. + + Gets bytes of TensorboardBlobs. + This is to allow reading blob data stored in consumer + project's Cloud Storage bucket without users having to + obtain Cloud Storage access permission. + + Returns: + Callable[[~.ReadTensorboardBlobDataRequest], + ~.ReadTensorboardBlobDataResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "read_tensorboard_blob_data" not in self._stubs: + self._stubs["read_tensorboard_blob_data"] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ReadTensorboardBlobData", + request_serializer=tensorboard_service.ReadTensorboardBlobDataRequest.serialize, + response_deserializer=tensorboard_service.ReadTensorboardBlobDataResponse.deserialize, + ) + return self._stubs["read_tensorboard_blob_data"] + + @property + def write_tensorboard_run_data( + self, + ) -> Callable[ + [tensorboard_service.WriteTensorboardRunDataRequest], + tensorboard_service.WriteTensorboardRunDataResponse, + ]: + r"""Return a callable for the write tensorboard run data method over gRPC. + + Write time series data points into multiple + TensorboardTimeSeries under a TensorboardRun. If any + data fail to be ingested, an error will be returned. + + Returns: + Callable[[~.WriteTensorboardRunDataRequest], + ~.WriteTensorboardRunDataResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "write_tensorboard_run_data" not in self._stubs: + self._stubs["write_tensorboard_run_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/WriteTensorboardRunData", + request_serializer=tensorboard_service.WriteTensorboardRunDataRequest.serialize, + response_deserializer=tensorboard_service.WriteTensorboardRunDataResponse.deserialize, + ) + return self._stubs["write_tensorboard_run_data"] + + @property + def export_tensorboard_time_series_data( + self, + ) -> Callable[ + [tensorboard_service.ExportTensorboardTimeSeriesDataRequest], + tensorboard_service.ExportTensorboardTimeSeriesDataResponse, + ]: + r"""Return a callable for the export tensorboard time series + data method over gRPC. + + Exports a TensorboardTimeSeries' data. Data is + returned in paginated responses. + + Returns: + Callable[[~.ExportTensorboardTimeSeriesDataRequest], + ~.ExportTensorboardTimeSeriesDataResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_tensorboard_time_series_data" not in self._stubs: + self._stubs[ + "export_tensorboard_time_series_data" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ExportTensorboardTimeSeriesData", + request_serializer=tensorboard_service.ExportTensorboardTimeSeriesDataRequest.serialize, + response_deserializer=tensorboard_service.ExportTensorboardTimeSeriesDataResponse.deserialize, + ) + return self._stubs["export_tensorboard_time_series_data"] + + +__all__ = ("TensorboardServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..d49895cdad --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py @@ -0,0 +1,980 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import TensorboardServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import TensorboardServiceGrpcTransport + + +class TensorboardServiceGrpcAsyncIOTransport(TensorboardServiceTransport): + """gRPC AsyncIO backend transport for TensorboardService. + + TensorboardService + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_tensorboard( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the create tensorboard method over gRPC. + + Creates a Tensorboard. + + Returns: + Callable[[~.CreateTensorboardRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard" not in self._stubs: + self._stubs["create_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboard", + request_serializer=tensorboard_service.CreateTensorboardRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_tensorboard"] + + @property + def get_tensorboard( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardRequest], Awaitable[tensorboard.Tensorboard] + ]: + r"""Return a callable for the get tensorboard method over gRPC. + + Gets a Tensorboard. + + Returns: + Callable[[~.GetTensorboardRequest], + Awaitable[~.Tensorboard]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard" not in self._stubs: + self._stubs["get_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboard", + request_serializer=tensorboard_service.GetTensorboardRequest.serialize, + response_deserializer=tensorboard.Tensorboard.deserialize, + ) + return self._stubs["get_tensorboard"] + + @property + def update_tensorboard( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the update tensorboard method over gRPC. + + Updates a Tensorboard. + + Returns: + Callable[[~.UpdateTensorboardRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard" not in self._stubs: + self._stubs["update_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboard", + request_serializer=tensorboard_service.UpdateTensorboardRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["update_tensorboard"] + + @property + def list_tensorboards( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardsRequest], + Awaitable[tensorboard_service.ListTensorboardsResponse], + ]: + r"""Return a callable for the list tensorboards method over gRPC. + + Lists Tensorboards in a Location. + + Returns: + Callable[[~.ListTensorboardsRequest], + Awaitable[~.ListTensorboardsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboards" not in self._stubs: + self._stubs["list_tensorboards"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboards", + request_serializer=tensorboard_service.ListTensorboardsRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardsResponse.deserialize, + ) + return self._stubs["list_tensorboards"] + + @property + def delete_tensorboard( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete tensorboard method over gRPC. + + Deletes a Tensorboard. + + Returns: + Callable[[~.DeleteTensorboardRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard" not in self._stubs: + self._stubs["delete_tensorboard"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboard", + request_serializer=tensorboard_service.DeleteTensorboardRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard"] + + @property + def create_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardExperimentRequest], + Awaitable[gca_tensorboard_experiment.TensorboardExperiment], + ]: + r"""Return a callable for the create tensorboard experiment method over gRPC. + + Creates a TensorboardExperiment. + + Returns: + Callable[[~.CreateTensorboardExperimentRequest], + Awaitable[~.TensorboardExperiment]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard_experiment" not in self._stubs: + self._stubs[ + "create_tensorboard_experiment" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboardExperiment", + request_serializer=tensorboard_service.CreateTensorboardExperimentRequest.serialize, + response_deserializer=gca_tensorboard_experiment.TensorboardExperiment.deserialize, + ) + return self._stubs["create_tensorboard_experiment"] + + @property + def get_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardExperimentRequest], + Awaitable[tensorboard_experiment.TensorboardExperiment], + ]: + r"""Return a callable for the get tensorboard experiment method over gRPC. + + Gets a TensorboardExperiment. + + Returns: + Callable[[~.GetTensorboardExperimentRequest], + Awaitable[~.TensorboardExperiment]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard_experiment" not in self._stubs: + self._stubs["get_tensorboard_experiment"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboardExperiment", + request_serializer=tensorboard_service.GetTensorboardExperimentRequest.serialize, + response_deserializer=tensorboard_experiment.TensorboardExperiment.deserialize, + ) + return self._stubs["get_tensorboard_experiment"] + + @property + def update_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardExperimentRequest], + Awaitable[gca_tensorboard_experiment.TensorboardExperiment], + ]: + r"""Return a callable for the update tensorboard experiment method over gRPC. + + Updates a TensorboardExperiment. + + Returns: + Callable[[~.UpdateTensorboardExperimentRequest], + Awaitable[~.TensorboardExperiment]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard_experiment" not in self._stubs: + self._stubs[ + "update_tensorboard_experiment" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboardExperiment", + request_serializer=tensorboard_service.UpdateTensorboardExperimentRequest.serialize, + response_deserializer=gca_tensorboard_experiment.TensorboardExperiment.deserialize, + ) + return self._stubs["update_tensorboard_experiment"] + + @property + def list_tensorboard_experiments( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardExperimentsRequest], + Awaitable[tensorboard_service.ListTensorboardExperimentsResponse], + ]: + r"""Return a callable for the list tensorboard experiments method over gRPC. + + Lists TensorboardExperiments in a Location. + + Returns: + Callable[[~.ListTensorboardExperimentsRequest], + Awaitable[~.ListTensorboardExperimentsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboard_experiments" not in self._stubs: + self._stubs["list_tensorboard_experiments"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboardExperiments", + request_serializer=tensorboard_service.ListTensorboardExperimentsRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardExperimentsResponse.deserialize, + ) + return self._stubs["list_tensorboard_experiments"] + + @property + def delete_tensorboard_experiment( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardExperimentRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete tensorboard experiment method over gRPC. + + Deletes a TensorboardExperiment. + + Returns: + Callable[[~.DeleteTensorboardExperimentRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard_experiment" not in self._stubs: + self._stubs[ + "delete_tensorboard_experiment" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboardExperiment", + request_serializer=tensorboard_service.DeleteTensorboardExperimentRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard_experiment"] + + @property + def create_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardRunRequest], + Awaitable[gca_tensorboard_run.TensorboardRun], + ]: + r"""Return a callable for the create tensorboard run method over gRPC. + + Creates a TensorboardRun. + + Returns: + Callable[[~.CreateTensorboardRunRequest], + Awaitable[~.TensorboardRun]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard_run" not in self._stubs: + self._stubs["create_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboardRun", + request_serializer=tensorboard_service.CreateTensorboardRunRequest.serialize, + response_deserializer=gca_tensorboard_run.TensorboardRun.deserialize, + ) + return self._stubs["create_tensorboard_run"] + + @property + def get_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardRunRequest], + Awaitable[tensorboard_run.TensorboardRun], + ]: + r"""Return a callable for the get tensorboard run method over gRPC. + + Gets a TensorboardRun. + + Returns: + Callable[[~.GetTensorboardRunRequest], + Awaitable[~.TensorboardRun]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard_run" not in self._stubs: + self._stubs["get_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboardRun", + request_serializer=tensorboard_service.GetTensorboardRunRequest.serialize, + response_deserializer=tensorboard_run.TensorboardRun.deserialize, + ) + return self._stubs["get_tensorboard_run"] + + @property + def update_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardRunRequest], + Awaitable[gca_tensorboard_run.TensorboardRun], + ]: + r"""Return a callable for the update tensorboard run method over gRPC. + + Updates a TensorboardRun. + + Returns: + Callable[[~.UpdateTensorboardRunRequest], + Awaitable[~.TensorboardRun]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard_run" not in self._stubs: + self._stubs["update_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboardRun", + request_serializer=tensorboard_service.UpdateTensorboardRunRequest.serialize, + response_deserializer=gca_tensorboard_run.TensorboardRun.deserialize, + ) + return self._stubs["update_tensorboard_run"] + + @property + def list_tensorboard_runs( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardRunsRequest], + Awaitable[tensorboard_service.ListTensorboardRunsResponse], + ]: + r"""Return a callable for the list tensorboard runs method over gRPC. + + Lists TensorboardRuns in a Location. + + Returns: + Callable[[~.ListTensorboardRunsRequest], + Awaitable[~.ListTensorboardRunsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboard_runs" not in self._stubs: + self._stubs["list_tensorboard_runs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboardRuns", + request_serializer=tensorboard_service.ListTensorboardRunsRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardRunsResponse.deserialize, + ) + return self._stubs["list_tensorboard_runs"] + + @property + def delete_tensorboard_run( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardRunRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete tensorboard run method over gRPC. + + Deletes a TensorboardRun. + + Returns: + Callable[[~.DeleteTensorboardRunRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard_run" not in self._stubs: + self._stubs["delete_tensorboard_run"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboardRun", + request_serializer=tensorboard_service.DeleteTensorboardRunRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard_run"] + + @property + def create_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.CreateTensorboardTimeSeriesRequest], + Awaitable[gca_tensorboard_time_series.TensorboardTimeSeries], + ]: + r"""Return a callable for the create tensorboard time series method over gRPC. + + Creates a TensorboardTimeSeries. + + Returns: + Callable[[~.CreateTensorboardTimeSeriesRequest], + Awaitable[~.TensorboardTimeSeries]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tensorboard_time_series" not in self._stubs: + self._stubs[ + "create_tensorboard_time_series" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/CreateTensorboardTimeSeries", + request_serializer=tensorboard_service.CreateTensorboardTimeSeriesRequest.serialize, + response_deserializer=gca_tensorboard_time_series.TensorboardTimeSeries.deserialize, + ) + return self._stubs["create_tensorboard_time_series"] + + @property + def get_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.GetTensorboardTimeSeriesRequest], + Awaitable[tensorboard_time_series.TensorboardTimeSeries], + ]: + r"""Return a callable for the get tensorboard time series method over gRPC. + + Gets a TensorboardTimeSeries. + + Returns: + Callable[[~.GetTensorboardTimeSeriesRequest], + Awaitable[~.TensorboardTimeSeries]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tensorboard_time_series" not in self._stubs: + self._stubs["get_tensorboard_time_series"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/GetTensorboardTimeSeries", + request_serializer=tensorboard_service.GetTensorboardTimeSeriesRequest.serialize, + response_deserializer=tensorboard_time_series.TensorboardTimeSeries.deserialize, + ) + return self._stubs["get_tensorboard_time_series"] + + @property + def update_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.UpdateTensorboardTimeSeriesRequest], + Awaitable[gca_tensorboard_time_series.TensorboardTimeSeries], + ]: + r"""Return a callable for the update tensorboard time series method over gRPC. + + Updates a TensorboardTimeSeries. + + Returns: + Callable[[~.UpdateTensorboardTimeSeriesRequest], + Awaitable[~.TensorboardTimeSeries]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tensorboard_time_series" not in self._stubs: + self._stubs[ + "update_tensorboard_time_series" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/UpdateTensorboardTimeSeries", + request_serializer=tensorboard_service.UpdateTensorboardTimeSeriesRequest.serialize, + response_deserializer=gca_tensorboard_time_series.TensorboardTimeSeries.deserialize, + ) + return self._stubs["update_tensorboard_time_series"] + + @property + def list_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.ListTensorboardTimeSeriesRequest], + Awaitable[tensorboard_service.ListTensorboardTimeSeriesResponse], + ]: + r"""Return a callable for the list tensorboard time series method over gRPC. + + Lists TensorboardTimeSeries in a Location. + + Returns: + Callable[[~.ListTensorboardTimeSeriesRequest], + Awaitable[~.ListTensorboardTimeSeriesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tensorboard_time_series" not in self._stubs: + self._stubs["list_tensorboard_time_series"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ListTensorboardTimeSeries", + request_serializer=tensorboard_service.ListTensorboardTimeSeriesRequest.serialize, + response_deserializer=tensorboard_service.ListTensorboardTimeSeriesResponse.deserialize, + ) + return self._stubs["list_tensorboard_time_series"] + + @property + def delete_tensorboard_time_series( + self, + ) -> Callable[ + [tensorboard_service.DeleteTensorboardTimeSeriesRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete tensorboard time series method over gRPC. + + Deletes a TensorboardTimeSeries. + + Returns: + Callable[[~.DeleteTensorboardTimeSeriesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tensorboard_time_series" not in self._stubs: + self._stubs[ + "delete_tensorboard_time_series" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/DeleteTensorboardTimeSeries", + request_serializer=tensorboard_service.DeleteTensorboardTimeSeriesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_tensorboard_time_series"] + + @property + def read_tensorboard_time_series_data( + self, + ) -> Callable[ + [tensorboard_service.ReadTensorboardTimeSeriesDataRequest], + Awaitable[tensorboard_service.ReadTensorboardTimeSeriesDataResponse], + ]: + r"""Return a callable for the read tensorboard time series + data method over gRPC. + + Reads a TensorboardTimeSeries' data. Data is returned in + paginated responses. By default, if the number of data points + stored is less than 1000, all data will be returned. Otherwise, + 1000 data points will be randomly selected from this time series + and returned. This value can be changed by changing + max_data_points. + + Returns: + Callable[[~.ReadTensorboardTimeSeriesDataRequest], + Awaitable[~.ReadTensorboardTimeSeriesDataResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "read_tensorboard_time_series_data" not in self._stubs: + self._stubs[ + "read_tensorboard_time_series_data" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ReadTensorboardTimeSeriesData", + request_serializer=tensorboard_service.ReadTensorboardTimeSeriesDataRequest.serialize, + response_deserializer=tensorboard_service.ReadTensorboardTimeSeriesDataResponse.deserialize, + ) + return self._stubs["read_tensorboard_time_series_data"] + + @property + def read_tensorboard_blob_data( + self, + ) -> Callable[ + [tensorboard_service.ReadTensorboardBlobDataRequest], + Awaitable[tensorboard_service.ReadTensorboardBlobDataResponse], + ]: + r"""Return a callable for the read tensorboard blob data method over gRPC. + + Gets bytes of TensorboardBlobs. + This is to allow reading blob data stored in consumer + project's Cloud Storage bucket without users having to + obtain Cloud Storage access permission. + + Returns: + Callable[[~.ReadTensorboardBlobDataRequest], + Awaitable[~.ReadTensorboardBlobDataResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "read_tensorboard_blob_data" not in self._stubs: + self._stubs["read_tensorboard_blob_data"] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ReadTensorboardBlobData", + request_serializer=tensorboard_service.ReadTensorboardBlobDataRequest.serialize, + response_deserializer=tensorboard_service.ReadTensorboardBlobDataResponse.deserialize, + ) + return self._stubs["read_tensorboard_blob_data"] + + @property + def write_tensorboard_run_data( + self, + ) -> Callable[ + [tensorboard_service.WriteTensorboardRunDataRequest], + Awaitable[tensorboard_service.WriteTensorboardRunDataResponse], + ]: + r"""Return a callable for the write tensorboard run data method over gRPC. + + Write time series data points into multiple + TensorboardTimeSeries under a TensorboardRun. If any + data fail to be ingested, an error will be returned. + + Returns: + Callable[[~.WriteTensorboardRunDataRequest], + Awaitable[~.WriteTensorboardRunDataResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "write_tensorboard_run_data" not in self._stubs: + self._stubs["write_tensorboard_run_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/WriteTensorboardRunData", + request_serializer=tensorboard_service.WriteTensorboardRunDataRequest.serialize, + response_deserializer=tensorboard_service.WriteTensorboardRunDataResponse.deserialize, + ) + return self._stubs["write_tensorboard_run_data"] + + @property + def export_tensorboard_time_series_data( + self, + ) -> Callable[ + [tensorboard_service.ExportTensorboardTimeSeriesDataRequest], + Awaitable[tensorboard_service.ExportTensorboardTimeSeriesDataResponse], + ]: + r"""Return a callable for the export tensorboard time series + data method over gRPC. + + Exports a TensorboardTimeSeries' data. Data is + returned in paginated responses. + + Returns: + Callable[[~.ExportTensorboardTimeSeriesDataRequest], + Awaitable[~.ExportTensorboardTimeSeriesDataResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_tensorboard_time_series_data" not in self._stubs: + self._stubs[ + "export_tensorboard_time_series_data" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.TensorboardService/ExportTensorboardTimeSeriesData", + request_serializer=tensorboard_service.ExportTensorboardTimeSeriesDataRequest.serialize, + response_deserializer=tensorboard_service.ExportTensorboardTimeSeriesDataResponse.deserialize, + ) + return self._stubs["export_tensorboard_time_series_data"] + + +__all__ = ("TensorboardServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index fb550654e8..0b02ac1777 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -115,7 +115,6 @@ FeatureValueList, ReadFeatureValuesRequest, ReadFeatureValuesResponse, - ReadSetting, StreamingReadFeatureValuesRequest, ) from .featurestore_service import ( @@ -135,6 +134,9 @@ DeleteFeatureRequest, DeleteFeaturestoreRequest, DestinationFeatureSetting, + ExportFeatureValuesOperationMetadata, + ExportFeatureValuesRequest, + ExportFeatureValuesResponse, FeatureValueDestination, GetEntityTypeRequest, GetFeatureRequest, @@ -343,11 +345,23 @@ DeleteOperationMetadata, GenericOperationMetadata, ) +from .pipeline_job import ( + PipelineJob, + PipelineJobDetail, + PipelineTaskDetail, + PipelineTaskExecutorDetail, +) from .pipeline_service import ( + CancelPipelineJobRequest, CancelTrainingPipelineRequest, + CreatePipelineJobRequest, CreateTrainingPipelineRequest, + DeletePipelineJobRequest, DeleteTrainingPipelineRequest, + GetPipelineJobRequest, GetTrainingPipelineRequest, + ListPipelineJobsRequest, + ListPipelineJobsResponse, ListTrainingPipelinesRequest, ListTrainingPipelinesResponse, ) @@ -374,6 +388,54 @@ StudySpec, Trial, ) +from .tensorboard import Tensorboard +from .tensorboard_data import ( + Scalar, + TensorboardBlob, + TensorboardBlobSequence, + TensorboardTensor, + TimeSeriesData, + TimeSeriesDataPoint, +) +from .tensorboard_experiment import TensorboardExperiment +from .tensorboard_run import TensorboardRun +from .tensorboard_service import ( + CreateTensorboardExperimentRequest, + CreateTensorboardOperationMetadata, + CreateTensorboardRequest, + CreateTensorboardRunRequest, + CreateTensorboardTimeSeriesRequest, + DeleteTensorboardExperimentRequest, + DeleteTensorboardRequest, + DeleteTensorboardRunRequest, + DeleteTensorboardTimeSeriesRequest, + ExportTensorboardTimeSeriesDataRequest, + ExportTensorboardTimeSeriesDataResponse, + GetTensorboardExperimentRequest, + GetTensorboardRequest, + GetTensorboardRunRequest, + GetTensorboardTimeSeriesRequest, + ListTensorboardExperimentsRequest, + ListTensorboardExperimentsResponse, + ListTensorboardRunsRequest, + ListTensorboardRunsResponse, + ListTensorboardsRequest, + ListTensorboardsResponse, + ListTensorboardTimeSeriesRequest, + ListTensorboardTimeSeriesResponse, + ReadTensorboardBlobDataRequest, + ReadTensorboardBlobDataResponse, + ReadTensorboardTimeSeriesDataRequest, + ReadTensorboardTimeSeriesDataResponse, + UpdateTensorboardExperimentRequest, + UpdateTensorboardOperationMetadata, + UpdateTensorboardRequest, + UpdateTensorboardRunRequest, + UpdateTensorboardTimeSeriesRequest, + WriteTensorboardRunDataRequest, + WriteTensorboardRunDataResponse, +) +from .tensorboard_time_series import TensorboardTimeSeries from .training_pipeline import ( FilterSplit, FractionSplit, @@ -389,6 +451,7 @@ StringArray, ) from .user_action_reference import UserActionReference +from .value import Value from .vizier_service import ( AddTrialMeasurementRequest, CheckTrialEarlyStoppingStateMetatdata, @@ -499,7 +562,6 @@ "FeatureValueList", "ReadFeatureValuesRequest", "ReadFeatureValuesResponse", - "ReadSetting", "StreamingReadFeatureValuesRequest", "BatchCreateFeaturesOperationMetadata", "BatchCreateFeaturesRequest", @@ -517,6 +579,9 @@ "DeleteFeatureRequest", "DeleteFeaturestoreRequest", "DestinationFeatureSetting", + "ExportFeatureValuesOperationMetadata", + "ExportFeatureValuesRequest", + "ExportFeatureValuesResponse", "FeatureValueDestination", "GetEntityTypeRequest", "GetFeatureRequest", @@ -699,10 +764,20 @@ "UploadModelResponse", "DeleteOperationMetadata", "GenericOperationMetadata", + "PipelineJob", + "PipelineJobDetail", + "PipelineTaskDetail", + "PipelineTaskExecutorDetail", + "CancelPipelineJobRequest", "CancelTrainingPipelineRequest", + "CreatePipelineJobRequest", "CreateTrainingPipelineRequest", + "DeletePipelineJobRequest", "DeleteTrainingPipelineRequest", + "GetPipelineJobRequest", "GetTrainingPipelineRequest", + "ListPipelineJobsRequest", + "ListPipelineJobsResponse", "ListTrainingPipelinesRequest", "ListTrainingPipelinesResponse", "PipelineState", @@ -723,6 +798,50 @@ "Study", "StudySpec", "Trial", + "Tensorboard", + "Scalar", + "TensorboardBlob", + "TensorboardBlobSequence", + "TensorboardTensor", + "TimeSeriesData", + "TimeSeriesDataPoint", + "TensorboardExperiment", + "TensorboardRun", + "CreateTensorboardExperimentRequest", + "CreateTensorboardOperationMetadata", + "CreateTensorboardRequest", + "CreateTensorboardRunRequest", + "CreateTensorboardTimeSeriesRequest", + "DeleteTensorboardExperimentRequest", + "DeleteTensorboardRequest", + "DeleteTensorboardRunRequest", + "DeleteTensorboardTimeSeriesRequest", + "ExportTensorboardTimeSeriesDataRequest", + "ExportTensorboardTimeSeriesDataResponse", + "GetTensorboardExperimentRequest", + "GetTensorboardRequest", + "GetTensorboardRunRequest", + "GetTensorboardTimeSeriesRequest", + "ListTensorboardExperimentsRequest", + "ListTensorboardExperimentsResponse", + "ListTensorboardRunsRequest", + "ListTensorboardRunsResponse", + "ListTensorboardsRequest", + "ListTensorboardsResponse", + "ListTensorboardTimeSeriesRequest", + "ListTensorboardTimeSeriesResponse", + "ReadTensorboardBlobDataRequest", + "ReadTensorboardBlobDataResponse", + "ReadTensorboardTimeSeriesDataRequest", + "ReadTensorboardTimeSeriesDataResponse", + "UpdateTensorboardExperimentRequest", + "UpdateTensorboardOperationMetadata", + "UpdateTensorboardRequest", + "UpdateTensorboardRunRequest", + "UpdateTensorboardTimeSeriesRequest", + "WriteTensorboardRunDataRequest", + "WriteTensorboardRunDataResponse", + "TensorboardTimeSeries", "FilterSplit", "FractionSplit", "InputDataConfig", @@ -734,6 +853,7 @@ "Int64Array", "StringArray", "UserActionReference", + "Value", "AddTrialMeasurementRequest", "CheckTrialEarlyStoppingStateMetatdata", "CheckTrialEarlyStoppingStateRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 1d148b7777..22e32f5996 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -175,6 +175,12 @@ class CustomJobSpec(proto.Message): ``//checkpoints/`` - AIP_TENSORBOARD_LOG_DIR = ``//logs/`` + tensorboard (str): + Optional. The name of an AI Platform + ``Tensorboard`` + resource to which this CustomJob will upload Tensorboard + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` """ worker_pool_specs = proto.RepeatedField( @@ -191,6 +197,8 @@ class CustomJobSpec(proto.Message): proto.MESSAGE, number=6, message=io.GcsDestination, ) + tensorboard = proto.Field(proto.STRING, number=7) + class WorkerPoolSpec(proto.Message): r"""Represents the spec of a worker pool in a job. diff --git a/google/cloud/aiplatform_v1beta1/types/entity_type.py b/google/cloud/aiplatform_v1beta1/types/entity_type.py index eabbe9190a..c1e599c569 100644 --- a/google/cloud/aiplatform_v1beta1/types/entity_type.py +++ b/google/cloud/aiplatform_v1beta1/types/entity_type.py @@ -70,6 +70,9 @@ class EntityType(proto.Message): odify-write updates. If not set, a blind "overwrite" update happens. monitoring_config (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig): + Optional. The default monitoring configuration for all + Features under this EntityType. + If this is populated with [FeaturestoreMonitoringConfig.monitoring_interval] specified, snapshot analysis monitoring is enabled. diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 69947e9b9e..330d0a8ed5 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -242,7 +242,7 @@ class Visualization(proto.Message): clip_percent_lowerbound (float): Excludes attributions below the specified percentile, from the highlighted areas. Defaults - to 35. + to 62. overlay_type (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Visualization.OverlayType): How the original image is displayed in the visualization. Adjusting the overlay can help diff --git a/google/cloud/aiplatform_v1beta1/types/feature.py b/google/cloud/aiplatform_v1beta1/types/feature.py index eed5209479..4d0d0388b7 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature.py +++ b/google/cloud/aiplatform_v1beta1/types/feature.py @@ -72,6 +72,10 @@ class Feature(proto.Message): rite updates. If not set, a blind "overwrite" update happens. monitoring_config (google.cloud.aiplatform_v1beta1.types.FeaturestoreMonitoringConfig): + Optional. The custom monitoring configuration for this + Feature, if not set, use the monitoring_config defined for + the EntityType this Feature belongs to. + If this is populated with [FeaturestoreMonitoringConfig.disabled][] = true, snapshot analysis monitoring is disabled; if diff --git a/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py b/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py index a0c6d51e0f..af3d2ee034 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/feature_monitoring_stats.py @@ -76,11 +76,17 @@ class FeatureStatsAnomaly(proto.Message): different from ``ThresholdConfig.value``. start_time (google.protobuf.timestamp_pb2.Timestamp): - The start timestamp of window where stats - were generated. + The start timestamp of window where stats were generated. + For objectives where time window doesn't make sense (e.g. + Featurestore Snapshot Monitoring), start_time is only used + to indicate the monitoring intervals, so it always equals to + (end_time - monitoring_interval). end_time (google.protobuf.timestamp_pb2.Timestamp): - The end timestamp of window where stats were - generated. + The end timestamp of window where stats were generated. For + objectives where time window doesn't make sense (e.g. + Featurestore Snapshot Monitoring), end_time indicates the + timestamp of the data used to generate stats (e.g. timestamp + we take snapshots for feature values). """ score = proto.Field(proto.DOUBLE, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore.py b/google/cloud/aiplatform_v1beta1/types/featurestore.py index 378b651b42..670453f362 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore.py @@ -88,23 +88,10 @@ class OnlineServingConfig(proto.Message): cluster. The number of nodes will not scale automatically but can be scaled manually by providing different values when updating. - max_online_serving_size (int): - Maximum number of feature values per entity - that will be stored in online serving storage. - The Featurestore will retain the latest feature - values per entity and periodically remove any - older feature values. It can take up to a day - before the older feature values are deleted. - Storage infrastructure cost is propotional to - this value. Recommend to set to the largest - number of versions (i.e last-k) needed at online - serving time. If not set, default to 1. """ fixed_node_count = proto.Field(proto.INT32, number=2) - max_online_serving_size = proto.Field(proto.INT32, number=3) - name = proto.Field(proto.STRING, number=1) display_name = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py index 2ca2fe8dae..f7419243f0 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py @@ -29,7 +29,6 @@ package="google.cloud.aiplatform.v1beta1", manifest={ "ReadFeatureValuesRequest", - "ReadSetting", "ReadFeatureValuesResponse", "StreamingReadFeatureValuesRequest", "FeatureValue", @@ -56,17 +55,6 @@ class ReadFeatureValuesRequest(proto.Message): feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): Required. Selector choosing Features of the target EntityType. - setting (google.cloud.aiplatform_v1beta1.types.ReadSetting): - Setting to apply to all Feature values being - read, by default. - setting_overrides (Sequence[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesRequest.SettingOverridesEntry]): - Map from Feature ID to settings to apply to Feature values - being read. If no setting is specified for a Feature - selected by - ``ReadFeatureValuesRequest.feature_selector``, - the default - ``ReadFeatureValuesRequest.setting`` - will be used. """ entity_type = proto.Field(proto.STRING, number=1) @@ -77,33 +65,6 @@ class ReadFeatureValuesRequest(proto.Message): proto.MESSAGE, number=3, message=gca_feature_selector.FeatureSelector, ) - setting = proto.Field(proto.MESSAGE, number=5, message="ReadSetting",) - - setting_overrides = proto.MapField( - proto.STRING, proto.MESSAGE, number=6, message="ReadSetting", - ) - - -class ReadSetting(proto.Message): - r"""Setting to apply when reading Feature values, e.g. "limit - read to the K-latest values". - - Attributes: - values_count (int): - Number of values, successive in time, to - retrieve for a Feature. If not set, default to - 1. Must be less than or equal to 32. - read_time (google.protobuf.timestamp_pb2.Timestamp): - Retrieve latest values before or at this - timestamp. If not set, retrieve latest values. - Resolution in millisecond. Request will fail if - timestamp is not millisecond-aligned. - """ - - values_count = proto.Field(proto.INT32, number=2) - - read_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - class ReadFeatureValuesResponse(proto.Message): r"""Response message for @@ -223,17 +184,6 @@ class StreamingReadFeatureValuesRequest(proto.Message): feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): Required. Selector choosing Features of the target EntityType. - setting (google.cloud.aiplatform_v1beta1.types.ReadSetting): - Setting to apply to all Feature values being - read, by default. - setting_overrides (Sequence[google.cloud.aiplatform_v1beta1.types.StreamingReadFeatureValuesRequest.SettingOverridesEntry]): - Map from Feature ID to settings to apply to Feature values - being read. If no setting is specified for a Feature - selected by - ``ReadFeatureValuesRequest.feature_selector``, - the default - ``ReadFeatureValuesRequest.setting`` - will be used. """ entity_type = proto.Field(proto.STRING, number=1) @@ -244,12 +194,6 @@ class StreamingReadFeatureValuesRequest(proto.Message): proto.MESSAGE, number=3, message=gca_feature_selector.FeatureSelector, ) - setting = proto.Field(proto.MESSAGE, number=5, message="ReadSetting",) - - setting_overrides = proto.MapField( - proto.STRING, proto.MESSAGE, number=6, message="ReadSetting", - ) - class FeatureValue(proto.Message): r"""Value for a feature. diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py index 6bf6c284b2..4a1a5213ec 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py @@ -42,8 +42,10 @@ "ImportFeatureValuesRequest", "ImportFeatureValuesResponse", "BatchReadFeatureValuesRequest", + "ExportFeatureValuesRequest", "DestinationFeatureSetting", "FeatureValueDestination", + "ExportFeatureValuesResponse", "BatchReadFeatureValuesResponse", "CreateEntityTypeRequest", "GetEntityTypeRequest", @@ -64,6 +66,7 @@ "CreateFeaturestoreOperationMetadata", "UpdateFeaturestoreOperationMetadata", "ImportFeatureValuesOperationMetadata", + "ExportFeatureValuesOperationMetadata", "BatchReadFeatureValuesOperationMetadata", "CreateEntityTypeOperationMetadata", "CreateFeatureOperationMetadata", @@ -300,13 +303,15 @@ class ImportFeatureValuesRequest(proto.Message): where Feature generation timestamps are not in the timestamp range needed for online serving. worker_count (int): - Required. Specifies the number of workers - that are used to write data to the Featurestore. - Consider the online serving capacity that you - require to achieve the desired import throughput - without interfering with online serving. The - value must be greater than 0, and less than or - equal to 100. + Specifies the number of workers that are used + to write data to the Featurestore. Consider the + online serving capacity that you require to + achieve the desired import throughput without + interfering with online serving. The value must + be positive, and less than or equal to 100. If + not set, defaults to using 1 worker. The low + count ensures minimal impact on online serving + performance. """ class FeatureSpec(proto.Message): @@ -465,6 +470,63 @@ class EntityTypeSpec(proto.Message): ) +class ExportFeatureValuesRequest(proto.Message): + r"""Request message for + ``FeaturestoreService.ExportFeatureValues``. + + Attributes: + snapshot_export (google.cloud.aiplatform_v1beta1.types.ExportFeatureValuesRequest.SnapshotExport): + Exports Feature values of all entities of the + EntityType as of a snapshot time. + entity_type (str): + Required. The resource name of the EntityType from which to + export Feature values. Format: + ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}`` + destination (google.cloud.aiplatform_v1beta1.types.FeatureValueDestination): + Required. Specifies destination location and + format. + feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): + Required. Selects Features to export values + of. + settings (Sequence[google.cloud.aiplatform_v1beta1.types.DestinationFeatureSetting]): + Per-Feature export settings. + """ + + class SnapshotExport(proto.Message): + r"""Describes exporting Feature values as of the snapshot + timestamp. + + Attributes: + snapshot_time (google.protobuf.timestamp_pb2.Timestamp): + Exports Feature values as of this timestamp. + If not set, retrieve values as of now. + Timestamp, if present, must not have higher than + millisecond precision. + """ + + snapshot_time = proto.Field( + proto.MESSAGE, number=1, message=timestamp.Timestamp, + ) + + snapshot_export = proto.Field( + proto.MESSAGE, number=3, oneof="mode", message=SnapshotExport, + ) + + entity_type = proto.Field(proto.STRING, number=1) + + destination = proto.Field( + proto.MESSAGE, number=4, message="FeatureValueDestination", + ) + + feature_selector = proto.Field( + proto.MESSAGE, number=5, message=gca_feature_selector.FeatureSelector, + ) + + settings = proto.RepeatedField( + proto.MESSAGE, number=6, message="DestinationFeatureSetting", + ) + + class DestinationFeatureSetting(proto.Message): r""" @@ -525,6 +587,12 @@ class FeatureValueDestination(proto.Message): ) +class ExportFeatureValuesResponse(proto.Message): + r"""Response message for + ``FeaturestoreService.ExportFeatureValues``. + """ + + class BatchReadFeatureValuesResponse(proto.Message): r"""Response message for ``FeaturestoreService.BatchReadFeatureValues``. @@ -1142,6 +1210,20 @@ class ImportFeatureValuesOperationMetadata(proto.Message): imported_feature_value_count = proto.Field(proto.INT64, number=3) +class ExportFeatureValuesOperationMetadata(proto.Message): + r"""Details of operations that exports Features values. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Featurestore export + Feature values. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + class BatchReadFeatureValuesOperationMetadata(proto.Message): r"""Details of operations that batch reads Feature values. diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py index 28ce15cc75..b4cd467804 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py @@ -230,10 +230,17 @@ class AuthProvider(proto.Message): `audiences `__. that are allowed to access. A JWT containing any of these audiences will be accepted. + allowed_issuers (Sequence[str]): + A list of allowed JWT issuers. Each entry must be a valid + Google service account, in the following format: + + ``service-account-name@project-id.iam.gserviceaccount.com`` """ audiences = proto.RepeatedField(proto.STRING, number=1) + allowed_issuers = proto.RepeatedField(proto.STRING, number=2) + auth_provider = proto.Field(proto.MESSAGE, number=1, message=AuthProvider,) diff --git a/google/cloud/aiplatform_v1beta1/types/index_service.py b/google/cloud/aiplatform_v1beta1/types/index_service.py index 601f64c6e8..2e73b19d7c 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_service.py +++ b/google/cloud/aiplatform_v1beta1/types/index_service.py @@ -65,8 +65,8 @@ class CreateIndexOperationMetadata(proto.Message): generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. nearest_neighbor_search_operation_metadata (google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata): - The operation metadata with regard to ScaNN - Index operation. + The operation metadata with regard to + Matching Engine Index operation. """ generic_metadata = proto.Field( @@ -148,7 +148,8 @@ def raw_page(self): class UpdateIndexRequest(proto.Message): - r"""Request message for [IndexService.UpdateModel][]. + r"""Request message for + ``IndexService.UpdateIndex``. Attributes: index (google.cloud.aiplatform_v1beta1.types.Index): @@ -173,8 +174,8 @@ class UpdateIndexOperationMetadata(proto.Message): generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. nearest_neighbor_search_operation_metadata (google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata): - The operation metadata with regard to ScaNN - Index operation. + The operation metadata with regard to + Matching Engine Index operation. """ generic_metadata = proto.Field( @@ -201,13 +202,14 @@ class DeleteIndexRequest(proto.Message): class NearestNeighborSearchOperationMetadata(proto.Message): - r"""Runtime operation metadata with regard to ScaNN Index. + r"""Runtime operation metadata with regard to Matching Engine + Index. Attributes: content_validation_stats (Sequence[google.cloud.aiplatform_v1beta1.types.NearestNeighborSearchOperationMetadata.ContentValidationStats]): The validation stats of the content (per file) to be - inserted or updated on the ScaNN Index resource. Populated - if contentsDeltaUri is provided as part of + inserted or updated on the Matching Engine Index resource. + Populated if contentsDeltaUri is provided as part of ``Index.metadata``. Please note that, currently for those files that are broken or has unsupported file format, we will not have the stats diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index 7fd85c81b2..00922dce42 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -212,7 +212,7 @@ class CancelCustomJobRequest(proto.Message): class CreateDataLabelingJobRequest(proto.Message): r"""Request message for - [DataLabelingJobService.CreateDataLabelingJob][]. + ``JobService.CreateDataLabelingJob``. Attributes: parent (str): @@ -230,7 +230,8 @@ class CreateDataLabelingJobRequest(proto.Message): class GetDataLabelingJobRequest(proto.Message): - r"""Request message for [DataLabelingJobService.GetDataLabelingJob][]. + r"""Request message for + ``JobService.GetDataLabelingJob``. Attributes: name (str): @@ -242,7 +243,8 @@ class GetDataLabelingJobRequest(proto.Message): class ListDataLabelingJobsRequest(proto.Message): - r"""Request message for [DataLabelingJobService.ListDataLabelingJobs][]. + r"""Request message for + ``JobService.ListDataLabelingJobs``. Attributes: parent (str): @@ -334,7 +336,7 @@ class DeleteDataLabelingJobRequest(proto.Message): class CancelDataLabelingJobRequest(proto.Message): r"""Request message for - [DataLabelingJobService.CancelDataLabelingJob][]. + ``JobService.CancelDataLabelingJob``. Attributes: name (str): @@ -629,7 +631,7 @@ class CancelBatchPredictionJobRequest(proto.Message): class CreateModelDeploymentMonitoringJobRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.CreateModelDeploymentMonitoringJob][]. + ``JobService.CreateModelDeploymentMonitoringJob``. Attributes: parent (str): @@ -651,7 +653,7 @@ class CreateModelDeploymentMonitoringJobRequest(proto.Message): class SearchModelDeploymentMonitoringStatsAnomaliesRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies``. Attributes: model_deployment_monitoring_job (str): @@ -673,7 +675,7 @@ class SearchModelDeploymentMonitoringStatsAnomaliesRequest(proto.Message): The standard list page size. page_token (str): A page token received from a previous - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][] + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies`` call. start_time (google.protobuf.timestamp_pb2.Timestamp): The earliest timestamp of stats being @@ -731,17 +733,17 @@ class StatsAnomaliesObjective(proto.Message): class SearchModelDeploymentMonitoringStatsAnomaliesResponse(proto.Message): r"""Response message for - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][]. + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies``. Attributes: monitoring_stats (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringStatsAnomalies]): Stats retrieved for requested objectives. There are at most 1000 - [ModelMonitoringStatsAnomalies.feature_stats.prediction_stats][] + ``ModelMonitoringStatsAnomalies.FeatureHistoricStatsAnomalies.prediction_stats`` in the response. next_page_token (str): The page token that can be used by the next - [ModelDeploymentMonitoringJobService.SearchModelDeploymentMonitoringStatsAnomalies][] + ``JobService.SearchModelDeploymentMonitoringStatsAnomalies`` call. """ @@ -760,7 +762,7 @@ def raw_page(self): class GetModelDeploymentMonitoringJobRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.GetModelDeploymentMonitoringJob][]. + ``JobService.GetModelDeploymentMonitoringJob``. Attributes: name (str): @@ -774,7 +776,7 @@ class GetModelDeploymentMonitoringJobRequest(proto.Message): class ListModelDeploymentMonitoringJobsRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + ``JobService.ListModelDeploymentMonitoringJobs``. Attributes: parent (str): @@ -803,7 +805,7 @@ class ListModelDeploymentMonitoringJobsRequest(proto.Message): class ListModelDeploymentMonitoringJobsResponse(proto.Message): r"""Response message for - [ModelDeploymentMonitoringJobService.ListModelDeploymentMonitoringJobs][]. + ``JobService.ListModelDeploymentMonitoringJobs``. Attributes: model_deployment_monitoring_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob]): @@ -828,7 +830,7 @@ def raw_page(self): class UpdateModelDeploymentMonitoringJobRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + ``JobService.UpdateModelDeploymentMonitoringJob``. Attributes: model_deployment_monitoring_job (google.cloud.aiplatform_v1beta1.types.ModelDeploymentMonitoringJob): @@ -850,7 +852,7 @@ class UpdateModelDeploymentMonitoringJobRequest(proto.Message): class DeleteModelDeploymentMonitoringJobRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.DeleteModelDeploymentMonitoringJob][]. + ``JobService.DeleteModelDeploymentMonitoringJob``. Attributes: name (str): @@ -864,7 +866,7 @@ class DeleteModelDeploymentMonitoringJobRequest(proto.Message): class PauseModelDeploymentMonitoringJobRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.PauseModelDeploymentMonitoringJob][]. + ``JobService.PauseModelDeploymentMonitoringJob``. Attributes: name (str): @@ -878,7 +880,7 @@ class PauseModelDeploymentMonitoringJobRequest(proto.Message): class ResumeModelDeploymentMonitoringJobRequest(proto.Message): r"""Request message for - [ModelDeploymentMonitoringJobService.ResumeModelDeploymentMonitoringJob][]. + ``JobService.ResumeModelDeploymentMonitoringJob``. Attributes: name (str): @@ -892,7 +894,7 @@ class ResumeModelDeploymentMonitoringJobRequest(proto.Message): class UpdateModelDeploymentMonitoringJobOperationMetadata(proto.Message): r"""Runtime operation information for - [ModelDeploymentMonitoringJobService.UpdateModelDeploymentMonitoringJob][]. + ``JobService.UpdateModelDeploymentMonitoringJob``. Attributes: generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_service.py b/google/cloud/aiplatform_v1beta1/types/metadata_service.py index 20d13257e7..afbfd5872d 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_service.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_service.py @@ -173,8 +173,8 @@ class ListMetadataStoresResponse(proto.Message): The MetadataStores found for the Location. next_page_token (str): A token, which can be sent as - [MetadataService.ListMetadataStores.page_token][] to - retrieve the next page. If this field is not populated, + ``ListMetadataStoresRequest.page_token`` + to retrieve the next page. If this field is not populated, there are no subsequent pages. """ @@ -291,8 +291,31 @@ class ListArtifactsRequest(proto.Message): the call that provided the page token. (Otherwise the request will fail with INVALID_ARGUMENT error.) filter (str): - A query to filter available Artifacts for - matching results. + Filter specifying the boolean condition for the Artifacts to + satisfy in order to be part of the result set. The syntax to + define filter query is based on https://google.aip.dev/160. + The supported set of filters include the following: + + 1. Attributes filtering e.g. display_name = "test" + + Supported fields include: name, display_name, uri, state, + schema_title, create_time and update_time. Time fields, + i.e. create_time and update_time, require values to + specified in RFC-3339 format. e.g. create_time = + "2020-11-19T11:30:00-04:00" + + 2. Metadata field To filter on metadata fields use traversal + operation as follows: metadata.. + e.g. metadata.field_1.number_value = 10.0 + + 3. Context based filtering To filter Artifacts based on the + contexts to which they belong use the function operator + with the full resource name "in_context()" e.g. + in_context("projects//locations//metadataStores//contexts/") + + Each of the above supported filter types can be combined + together using Logical operators (AND & OR). e.g. + display_name = "test" AND metadata.field1.bool_value = true. """ parent = proto.Field(proto.STRING, number=1) @@ -314,9 +337,9 @@ class ListArtifactsResponse(proto.Message): MetadataStore. next_page_token (str): A token, which can be sent as - [MetadataService.ListArtifacts.page_token][] to retrieve the - next page. If this field is not populated, there are no - subsequent pages. + ``ListArtifactsRequest.page_token`` + to retrieve the next page. If this field is not populated, + there are no subsequent pages. """ @property @@ -425,8 +448,7 @@ class ListContextsRequest(proto.Message): the call that provided the page token. (Otherwise the request will fail with INVALID_ARGUMENT error.) filter (str): - A query to filter available Contexts for - matching results. + """ parent = proto.Field(proto.STRING, number=1) @@ -448,9 +470,9 @@ class ListContextsResponse(proto.Message): MetadataStore. next_page_token (str): A token, which can be sent as - [MetadataService.ListContexts.page_token][] to retrieve the - next page. If this field is not populated, there are no - subsequent pages. + ``ListContextsRequest.page_token`` + to retrieve the next page. If this field is not populated, + there are no subsequent pages. """ @property @@ -654,14 +676,32 @@ class ListExecutionsRequest(proto.Message): the call that provided the page token. (Otherwise the request will fail with INVALID_ARGUMENT error.) filter (str): - A query to filter available Executions for matching results. - Current implementation supports filtering on fields: + Filter specifying the boolean condition for the Executions + to satisfy in order to be part of the result set. The syntax + to define filter query is based on + https://google.aip.dev/160. Following are the supported set + of filters: + + 1. Attributes filtering e.g. display_name = "test" + + supported fields include: name, display_name, state, + schema_title, create_time and update_time. Time fields, + i.e. create_time and update_time, require values to + specified in RFC-3339 format. e.g. create_time = + "2020-11-19T11:30:00-04:00" - 1) display_name e.g display_name = "test_name" - 2) state e.g. state = RUNNING - 3) create_time and update_time e.g create_time > - "2020-12-17T13:25:12-08:00" - 4) metadata e.g metadata.flag.number_value > 1 + 2. Metadata field To filter on metadata fields use traversal + operation as follows: metadata.. + e.g. metadata.field_1.number_value = 10.0 + + 3. Context based filtering To filter Executions based on the + contexts to which they belong use the function operator + with the full resource name "in_context()" e.g. + in_context("projects//locations//metadataStores//contexts/") + + Each of the above supported filters can be combined together + using Logical operators (AND & OR). e.g. display_name = + "test" AND metadata.field1.bool_value = true. """ parent = proto.Field(proto.STRING, number=1) @@ -683,9 +723,9 @@ class ListExecutionsResponse(proto.Message): MetadataStore. next_page_token (str): A token, which can be sent as - [MetadataService.ListExecutions.page_token][] to retrieve - the next page. If this field is not populated, there are no - subsequent pages. + ``ListExecutionsRequest.page_token`` + to retrieve the next page. If this field is not populated, + there are no subsequent pages. """ @property @@ -862,8 +902,8 @@ class ListMetadataSchemasResponse(proto.Message): MetadataStore. next_page_token (str): A token, which can be sent as - [MetadataService.ListMetadataSchemas.page_token][] to - retrieve the next page. If this field is not populated, + ``ListMetadataSchemasRequest.page_token`` + to retrieve the next page. If this field is not populated, there are no subsequent pages. """ @@ -897,11 +937,35 @@ class QueryArtifactLineageSubgraphRequest(proto.Message): INVALID_ARGUMENT error is returned 0: Only input artifact is returned. No value: Transitive closure is performed to return the complete graph. + filter (str): + Filter specifying the boolean condition for the Artifacts to + satisfy in order to be part of the Lineage Subgraph. The + syntax to define filter query is based on + https://google.aip.dev/160. The supported set of filters + include the following: + + 1. Attributes filtering e.g. display_name = "test" + + supported fields include: name, display_name, uri, state, + schema_title, create_time and update_time. Time fields, + i.e. create_time and update_time, require values to + specified in RFC-3339 format. e.g. create_time = + "2020-11-19T11:30:00-04:00" + + 2. Metadata field To filter on metadata fields use traversal + operation as follows: metadata.. + e.g. metadata.field_1.number_value = 10.0 + + Each of the above supported filter types can be combined + together using Logical operators (AND & OR). e.g. + display_name = "test" AND metadata.field1.bool_value = true. """ artifact = proto.Field(proto.STRING, number=1) max_hops = proto.Field(proto.INT32, number=2) + filter = proto.Field(proto.STRING, number=3) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_store.py b/google/cloud/aiplatform_v1beta1/types/metadata_store.py index 19456d92eb..b57c00573a 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_store.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_store.py @@ -46,8 +46,24 @@ class MetadataStore(proto.Message): Metadata Store. If set, this Metadata Store and all sub-resources of this Metadata Store will be secured by this key. + description (str): + Description of the MetadataStore. + state (google.cloud.aiplatform_v1beta1.types.MetadataStore.MetadataStoreState): + Output only. State information of the + MetadataStore. """ + class MetadataStoreState(proto.Message): + r"""Represent state information for a MetadataStore. + + Attributes: + disk_utilization_bytes (int): + The disk utilization of the MetadataStore in + bytes. + """ + + disk_utilization_bytes = proto.Field(proto.INT64, number=1) + name = proto.Field(proto.STRING, number=1) create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) @@ -58,5 +74,9 @@ class MetadataStore(proto.Message): proto.MESSAGE, number=5, message=gca_encryption_spec.EncryptionSpec, ) + description = proto.Field(proto.STRING, number=6) + + state = proto.Field(proto.MESSAGE, number=7, message=MetadataStoreState,) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py index de4c9466f6..d1f232a91f 100644 --- a/google/cloud/aiplatform_v1beta1/types/migration_service.py +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -56,21 +56,25 @@ class SearchMigratableResourcesRequest(proto.Message): page_token (str): The standard page token. filter (str): - Supported filters are: + A filter for your search. You can use the following types of + filters: - - Resource type: For a specific type of MigratableResource. + - Resource type filters. The following strings filter for a + specific type of + ``MigratableResource``: - ``ml_engine_model_version:*`` - - ``automl_model:*``, + - ``automl_model:*`` - ``automl_dataset:*`` - - ``data_labeling_dataset:*``. + - ``data_labeling_dataset:*`` - - Migrated or not: Filter migrated resource or not by - last_migrate_time. + - "Migrated or not" filters. The following strings filter + for resources that either have or have not already been + migrated: - - ``last_migrate_time:*`` will filter migrated + - ``last_migrate_time:*`` filters for migrated resources. - - ``NOT last_migrate_time:*`` will filter not yet + - ``NOT last_migrate_time:*`` filters for not yet migrated resources. """ diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_job.py b/google/cloud/aiplatform_v1beta1/types/pipeline_job.py new file mode 100644 index 0000000000..975b1ca798 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_job.py @@ -0,0 +1,353 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import context +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types import pipeline_state +from google.cloud.aiplatform_v1beta1.types import value as gca_value +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "PipelineJob", + "PipelineJobDetail", + "PipelineTaskDetail", + "PipelineTaskExecutorDetail", + }, +) + + +class PipelineJob(proto.Message): + r"""An instance of a machine learning PipelineJob. + + Attributes: + name (str): + Output only. The resource name of the + PipelineJob. + display_name (str): + The display name of the Pipeline. + The name can be up to 128 characters long and + can be consist of any UTF-8 characters. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Pipeline creation time. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Pipeline start time. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Pipeline end time. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this PipelineJob + was most recently updated. + pipeline_spec (google.protobuf.struct_pb2.Struct): + Required. The spec of the pipeline. The spec contains a + ``schema_version`` field which indicates the Kubeflow + Pipeline schema version to decode the struct. + state (google.cloud.aiplatform_v1beta1.types.PipelineState): + Output only. The detailed state of the job. + job_detail (google.cloud.aiplatform_v1beta1.types.PipelineJobDetail): + Output only. The details of pipeline run. Not + available in the list view. + error (google.rpc.status_pb2.Status): + Output only. The error that occurred during + pipeline execution. Only populated when the + pipeline's state is FAILED or CANCELLED. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.PipelineJob.LabelsEntry]): + The labels with user-defined metadata to + organize PipelineJob. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + runtime_config (google.cloud.aiplatform_v1beta1.types.PipelineJob.RuntimeConfig): + Runtime config of the pipeline. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for a + pipelineJob. If set, this PipelineJob and all of + its sub-resources will be secured by this key. + service_account (str): + The service account that the pipeline workload runs as. If + not specified, the Compute Engine default service account in + the project will be used. See + https://cloud.google.com/compute/docs/access/service-accounts#default_service_account + + Users starting the pipeline must have the + ``iam.serviceAccounts.actAs`` permission on this service + account. + network (str): + The full name of the Compute Engine + `network `__ + to which the Pipeline Job's workload should be peered. For + example, ``projects/12345/global/networks/myVPC``. + `Format `__ + is of the form + ``projects/{project}/global/networks/{network}``. Where + {project} is a project number, as in ``12345``, and + {network} is a network name. + + Private services access must already be configured for the + network. Pipeline job will apply the network configuration + to the GCP resources being launched, if applied, such as + Cloud AI Platform Training or Dataflow job. If left + unspecified, the workload is not peered with any network. + """ + + class RuntimeConfig(proto.Message): + r"""The runtime config of a PipelineJob. + + Attributes: + parameters (Sequence[google.cloud.aiplatform_v1beta1.types.PipelineJob.RuntimeConfig.ParametersEntry]): + The runtime parameters of the PipelineJob. The parameters + will be passed into + ``PipelineJob.pipeline_spec`` + to replace the placeholders at runtime. + gcs_output_directory (str): + Required. A path in a Cloud Storage bucket, which will be + treated as the root output directory of the pipeline. It is + used by the system to generate the paths of output + artifacts. The artifact paths are generated with a sub-path + pattern ``{job_id}/{task_id}/{output_key}`` under the + specified output directory. The service account specified in + this pipeline must have the ``storage.objects.get`` and + ``storage.objects.create`` permissions for this bucket. + """ + + parameters = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message=gca_value.Value, + ) + + gcs_output_directory = proto.Field(proto.STRING, number=2) + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + + start_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + pipeline_spec = proto.Field(proto.MESSAGE, number=7, message=struct.Struct,) + + state = proto.Field(proto.ENUM, number=8, enum=pipeline_state.PipelineState,) + + job_detail = proto.Field(proto.MESSAGE, number=9, message="PipelineJobDetail",) + + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=11) + + runtime_config = proto.Field(proto.MESSAGE, number=12, message=RuntimeConfig,) + + encryption_spec = proto.Field( + proto.MESSAGE, number=16, message=gca_encryption_spec.EncryptionSpec, + ) + + service_account = proto.Field(proto.STRING, number=17) + + network = proto.Field(proto.STRING, number=18) + + +class PipelineJobDetail(proto.Message): + r"""The runtime detail of PipelineJob. + + Attributes: + pipeline_context (google.cloud.aiplatform_v1beta1.types.Context): + Output only. The context of the pipeline. + pipeline_run_context (google.cloud.aiplatform_v1beta1.types.Context): + Output only. The context of the current + pipeline run. + task_details (Sequence[google.cloud.aiplatform_v1beta1.types.PipelineTaskDetail]): + Output only. The runtime details of the tasks + under the pipeline. + """ + + pipeline_context = proto.Field(proto.MESSAGE, number=1, message=context.Context,) + + pipeline_run_context = proto.Field( + proto.MESSAGE, number=2, message=context.Context, + ) + + task_details = proto.RepeatedField( + proto.MESSAGE, number=3, message="PipelineTaskDetail", + ) + + +class PipelineTaskDetail(proto.Message): + r"""The runtime detail of a task execution. + + Attributes: + task_id (int): + Output only. The system generated ID of the + task. + parent_task_id (int): + Output only. The id of the parent task if the + task is within a component scope. Empty if the + task is at the root level. + task_name (str): + Output only. The user specified name of the task that is + defined in [PipelineJob.spec][]. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Task create time. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Task start time. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Task end time. + executor_detail (google.cloud.aiplatform_v1beta1.types.PipelineTaskExecutorDetail): + Output only. The detailed execution info. + state (google.cloud.aiplatform_v1beta1.types.PipelineTaskDetail.State): + Output only. State of the task. + execution (google.cloud.aiplatform_v1beta1.types.Execution): + Output only. The execution metadata of the + task. + error (google.rpc.status_pb2.Status): + Output only. The error that occurred during + task execution. Only populated when the task's + state is FAILED or CANCELLED. + inputs (Sequence[google.cloud.aiplatform_v1beta1.types.PipelineTaskDetail.InputsEntry]): + Output only. The runtime input artifacts of + the task. + outputs (Sequence[google.cloud.aiplatform_v1beta1.types.PipelineTaskDetail.OutputsEntry]): + Output only. The runtime output artifacts of + the task. + """ + + class State(proto.Enum): + r"""Specifies state of TaskExecution""" + STATE_UNSPECIFIED = 0 + PENDING = 1 + RUNNING = 2 + SUCCEEDED = 3 + CANCEL_PENDING = 4 + CANCELLING = 5 + CANCELLED = 6 + FAILED = 7 + SKIPPED = 8 + NOT_TRIGGERED = 9 + + class ArtifactList(proto.Message): + r"""A list of artifact metadata. + + Attributes: + artifacts (Sequence[google.cloud.aiplatform_v1beta1.types.Artifact]): + Output only. A list of artifact metadata. + """ + + artifacts = proto.RepeatedField( + proto.MESSAGE, number=1, message=artifact.Artifact, + ) + + task_id = proto.Field(proto.INT64, number=1) + + parent_task_id = proto.Field(proto.INT64, number=12) + + task_name = proto.Field(proto.STRING, number=2) + + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + + start_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + executor_detail = proto.Field( + proto.MESSAGE, number=6, message="PipelineTaskExecutorDetail", + ) + + state = proto.Field(proto.ENUM, number=7, enum=State,) + + execution = proto.Field(proto.MESSAGE, number=8, message=gca_execution.Execution,) + + error = proto.Field(proto.MESSAGE, number=9, message=status.Status,) + + inputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=10, message=ArtifactList, + ) + + outputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=11, message=ArtifactList, + ) + + +class PipelineTaskExecutorDetail(proto.Message): + r"""The runtime detail of a pipeline executor. + + Attributes: + container_detail (google.cloud.aiplatform_v1beta1.types.PipelineTaskExecutorDetail.ContainerDetail): + Output only. The detailed info for a + container executor. + custom_job_detail (google.cloud.aiplatform_v1beta1.types.PipelineTaskExecutorDetail.CustomJobDetail): + Output only. The detailed info for a custom + job executor. + """ + + class ContainerDetail(proto.Message): + r"""The detail of a container execution. It contains the job + names of the lifecycle of a container execution. + + Attributes: + main_job (str): + Output only. The name of the + ``CustomJob`` for + the main container execution. + pre_caching_check_job (str): + Output only. The name of the + ``CustomJob`` for + the pre-caching-check container execution. This job will be + available if the + ``PipelineJob.pipeline_spec`` + specifies the ``pre_caching_check`` hook in the lifecycle + events. + """ + + main_job = proto.Field(proto.STRING, number=1) + + pre_caching_check_job = proto.Field(proto.STRING, number=2) + + class CustomJobDetail(proto.Message): + r"""The detailed info for a custom job executor. + + Attributes: + job (str): + Output only. The name of the + ``CustomJob``. + """ + + job = proto.Field(proto.STRING, number=1) + + container_detail = proto.Field( + proto.MESSAGE, number=1, oneof="details", message=ContainerDetail, + ) + + custom_job_detail = proto.Field( + proto.MESSAGE, number=2, oneof="details", message=CustomJobDetail, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index b06361dfa9..e1d66a4517 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -18,6 +18,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import ( training_pipeline as gca_training_pipeline, ) @@ -33,6 +34,12 @@ "ListTrainingPipelinesResponse", "DeleteTrainingPipelineRequest", "CancelTrainingPipelineRequest", + "CreatePipelineJobRequest", + "GetPipelineJobRequest", + "ListPipelineJobsRequest", + "ListPipelineJobsResponse", + "DeletePipelineJobRequest", + "CancelPipelineJobRequest", }, ) @@ -171,4 +178,137 @@ class CancelTrainingPipelineRequest(proto.Message): name = proto.Field(proto.STRING, number=1) +class CreatePipelineJobRequest(proto.Message): + r"""Request message for + ``PipelineService.CreatePipelineJob``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + PipelineJob in. Format: + ``projects/{project}/locations/{location}`` + pipeline_job (google.cloud.aiplatform_v1beta1.types.PipelineJob): + Required. The PipelineJob to create. + pipeline_job_id (str): + The ID to use for the PipelineJob, which will become the + final component of the PipelineJob name. If not provided, an + ID will be automatically generated. + + This value should be less than 128 characters, and valid + characters are /[a-z][0-9]-/. + """ + + parent = proto.Field(proto.STRING, number=1) + + pipeline_job = proto.Field( + proto.MESSAGE, number=2, message=gca_pipeline_job.PipelineJob, + ) + + pipeline_job_id = proto.Field(proto.STRING, number=3) + + +class GetPipelineJobRequest(proto.Message): + r"""Request message for + ``PipelineService.GetPipelineJob``. + + Attributes: + name (str): + Required. The name of the PipelineJob resource. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListPipelineJobsRequest(proto.Message): + r"""Request message for + ``PipelineService.ListPipelineJobs``. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + PipelineJobs from. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + - ``NOT display_name="my_pipeline"`` + - ``state="PIPELINE_STATE_FAILED"`` + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListPipelineJobsResponse.next_page_token`` + of the previous + ``PipelineService.ListPipelineJobs`` + call. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + +class ListPipelineJobsResponse(proto.Message): + r"""Response message for + ``PipelineService.ListPipelineJobs`` + + Attributes: + pipeline_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.PipelineJob]): + List of PipelineJobs in the requested page. + next_page_token (str): + A token to retrieve the next page of results. Pass to + ``ListPipelineJobsRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + pipeline_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_pipeline_job.PipelineJob, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeletePipelineJobRequest(proto.Message): + r"""Request message for + ``PipelineService.DeletePipelineJob``. + + Attributes: + name (str): + Required. The name of the PipelineJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CancelPipelineJobRequest(proto.Message): + r"""Request message for + ``PipelineService.CancelPipelineJob``. + + Attributes: + name (str): + Required. The name of the PipelineJob to cancel. Format: + ``projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard.py b/google/cloud/aiplatform_v1beta1/types/tensorboard.py new file mode 100644 index 0000000000..45db95e7fb --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"Tensorboard",}, +) + + +class Tensorboard(proto.Message): + r"""Tensorboard is a physical database that stores users’ + training metrics. A default Tensorboard is provided in each + region of a GCP project. If needed users can also create extra + Tensorboards in their projects. + + Attributes: + name (str): + Output only. Name of the Tensorboard. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + display_name (str): + Required. User provided name of this + Tensorboard. + description (str): + Description of this Tensorboard. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for a + Tensorboard. If set, this Tensorboard and all + sub-resources of this Tensorboard will be + secured by this key. + blob_storage_path_prefix (str): + Output only. Consumer project Cloud Storage + path prefix used to store blob data, which can + either be a bucket or directory. Does not end + with a '/'. + run_count (int): + Output only. The number of Runs stored in + this Tensorboard. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Tensorboard + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Tensorboard + was last updated. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Tensorboard.LabelsEntry]): + The labels with user-defined metadata to + organize your Tensorboards. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. No more than 64 user labels can be + associated with one Tensorboard (System labels + are excluded). + + See https://goo.gl/xmQnxf for more information + and examples of labels. System reserved label + keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + etag (str): + Used to perform a consistent read-modify- + rite updates. If not set, a blind "overwrite" + update happens. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + encryption_spec = proto.Field( + proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, + ) + + blob_storage_path_prefix = proto.Field(proto.STRING, number=10) + + run_count = proto.Field(proto.INT32, number=5) + + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=8) + + etag = proto.Field(proto.STRING, number=9) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_data.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_data.py new file mode 100644 index 0000000000..9069e2ba30 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_data.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "TimeSeriesData", + "TimeSeriesDataPoint", + "Scalar", + "TensorboardTensor", + "TensorboardBlobSequence", + "TensorboardBlob", + }, +) + + +class TimeSeriesData(proto.Message): + r"""All the data stored in a TensorboardTimeSeries. + + Attributes: + tensorboard_time_series_id (str): + Required. The ID of the + TensorboardTimeSeries, which will become the + final component of the TensorboardTimeSeries' + resource name + value_type (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries.ValueType): + Required. Immutable. The value type of this + time series. All the values in this time series + data must match this value type. + values (Sequence[google.cloud.aiplatform_v1beta1.types.TimeSeriesDataPoint]): + Required. Data points in this time series. + """ + + tensorboard_time_series_id = proto.Field(proto.STRING, number=1) + + value_type = proto.Field( + proto.ENUM, + number=2, + enum=tensorboard_time_series.TensorboardTimeSeries.ValueType, + ) + + values = proto.RepeatedField( + proto.MESSAGE, number=3, message="TimeSeriesDataPoint", + ) + + +class TimeSeriesDataPoint(proto.Message): + r"""A TensorboardTimeSeries data point. + + Attributes: + scalar (google.cloud.aiplatform_v1beta1.types.Scalar): + A scalar value. + tensor (google.cloud.aiplatform_v1beta1.types.TensorboardTensor): + A tensor value. + blobs (google.cloud.aiplatform_v1beta1.types.TensorboardBlobSequence): + A blob sequence value. + wall_time (google.protobuf.timestamp_pb2.Timestamp): + Wall clock timestamp when this data point is + generated by the end user. + step (int): + Step index of this data point within the run. + """ + + scalar = proto.Field(proto.MESSAGE, number=3, oneof="value", message="Scalar",) + + tensor = proto.Field( + proto.MESSAGE, number=4, oneof="value", message="TensorboardTensor", + ) + + blobs = proto.Field( + proto.MESSAGE, number=5, oneof="value", message="TensorboardBlobSequence", + ) + + wall_time = proto.Field(proto.MESSAGE, number=1, message=timestamp.Timestamp,) + + step = proto.Field(proto.INT64, number=2) + + +class Scalar(proto.Message): + r"""One point viewable on a scalar metric plot. + + Attributes: + value (float): + Value of the point at this step / timestamp. + """ + + value = proto.Field(proto.DOUBLE, number=1) + + +class TensorboardTensor(proto.Message): + r"""One point viewable on a tensor metric plot. + + Attributes: + value (bytes): + Required. Serialized form of + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto + version_number (int): + Optional. Version number of TensorProto used to serialize + ``value``. + """ + + value = proto.Field(proto.BYTES, number=1) + + version_number = proto.Field(proto.INT32, number=2) + + +class TensorboardBlobSequence(proto.Message): + r"""One point viewable on a blob metric plot, but mostly just a wrapper + message to work around repeated fields can't be used directly within + ``oneof`` fields. + + Attributes: + values (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardBlob]): + List of blobs contained within the sequence. + """ + + values = proto.RepeatedField(proto.MESSAGE, number=1, message="TensorboardBlob",) + + +class TensorboardBlob(proto.Message): + r"""One blob (e.g, image, graph) viewable on a blob metric plot. + + Attributes: + id (str): + Output only. A URI safe key uniquely + identifying a blob. Can be used to locate the + blob stored in the Cloud Storage bucket of the + consumer project. + data (bytes): + Optional. The bytes of the blob is not + present unless it's returned by the + ReadTensorboardBlobData endpoint. + """ + + id = proto.Field(proto.STRING, number=1) + + data = proto.Field(proto.BYTES, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py new file mode 100644 index 0000000000..6c073aa5e8 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"TensorboardExperiment",}, +) + + +class TensorboardExperiment(proto.Message): + r"""A TensorboardExperiment is a group of TensorboardRuns, that + are typically the results of a training job run, in a + Tensorboard. + + Attributes: + name (str): + Output only. Name of the TensorboardExperiment. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + display_name (str): + User provided name of this + TensorboardExperiment. + description (str): + Description of this TensorboardExperiment. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + TensorboardExperiment was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + TensorboardExperiment was last updated. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardExperiment.LabelsEntry]): + The labels with user-defined metadata to organize your + Datasets. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. No more than 64 user labels can be + associated with one Dataset (System labels are excluded). + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each Dataset: + + - "aiplatform.googleapis.com/dataset_metadata_schema": + + - output only, its value is the + [metadata_schema's][metadata_schema_uri] title. + etag (str): + Used to perform consistent read-modify-write + updates. If not set, a blind "overwrite" update + happens. + source (str): + Immutable. Source of the + TensorboardExperiment. Example: a custom + training job. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=6) + + etag = proto.Field(proto.STRING, number=7) + + source = proto.Field(proto.STRING, number=8) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py new file mode 100644 index 0000000000..f9cff272c4 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_run.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"TensorboardRun",}, +) + + +class TensorboardRun(proto.Message): + r"""TensorboardRun maps to a specific execution of a training job + with a given set of hyperparameter values, model definition, + dataset, etc + + Attributes: + name (str): + Output only. Name of the TensorboardRun. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + display_name (str): + Required. User provided name of this + TensorboardRun. This value must be unique among + all TensorboardRuns belonging to the same parent + TensorboardExperiment. + description (str): + Description of this TensorboardRun. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + TensorboardRun was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + TensorboardRun was last updated. + labels (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardRun.LabelsEntry]): + + etag (str): + Used to perform a consistent read-modify- + rite updates. If not set, a blind "overwrite" + update happens. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=8) + + etag = proto.Field(proto.STRING, number=9) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py new file mode 100644 index 0000000000..f208059201 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py @@ -0,0 +1,892 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import operation +from google.cloud.aiplatform_v1beta1.types import tensorboard as gca_tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_data +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "CreateTensorboardRequest", + "GetTensorboardRequest", + "ListTensorboardsRequest", + "ListTensorboardsResponse", + "UpdateTensorboardRequest", + "DeleteTensorboardRequest", + "CreateTensorboardExperimentRequest", + "GetTensorboardExperimentRequest", + "ListTensorboardExperimentsRequest", + "ListTensorboardExperimentsResponse", + "UpdateTensorboardExperimentRequest", + "DeleteTensorboardExperimentRequest", + "CreateTensorboardRunRequest", + "GetTensorboardRunRequest", + "ReadTensorboardBlobDataRequest", + "ReadTensorboardBlobDataResponse", + "ListTensorboardRunsRequest", + "ListTensorboardRunsResponse", + "UpdateTensorboardRunRequest", + "DeleteTensorboardRunRequest", + "CreateTensorboardTimeSeriesRequest", + "GetTensorboardTimeSeriesRequest", + "ListTensorboardTimeSeriesRequest", + "ListTensorboardTimeSeriesResponse", + "UpdateTensorboardTimeSeriesRequest", + "DeleteTensorboardTimeSeriesRequest", + "ReadTensorboardTimeSeriesDataRequest", + "ReadTensorboardTimeSeriesDataResponse", + "WriteTensorboardRunDataRequest", + "WriteTensorboardRunDataResponse", + "ExportTensorboardTimeSeriesDataRequest", + "ExportTensorboardTimeSeriesDataResponse", + "CreateTensorboardOperationMetadata", + "UpdateTensorboardOperationMetadata", + }, +) + + +class CreateTensorboardRequest(proto.Message): + r"""Request message for + ``TensorboardService.CreateTensorboard``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + Tensorboard in. Format: + ``projects/{project}/locations/{location}`` + tensorboard (google.cloud.aiplatform_v1beta1.types.Tensorboard): + Required. The Tensorboard to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + tensorboard = proto.Field( + proto.MESSAGE, number=2, message=gca_tensorboard.Tensorboard, + ) + + +class GetTensorboardRequest(proto.Message): + r"""Request message for + ``TensorboardService.GetTensorboard``. + + Attributes: + name (str): + Required. The name of the Tensorboard resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListTensorboardsRequest(proto.Message): + r"""Request message for + ``TensorboardService.ListTensorboards``. + + Attributes: + parent (str): + Required. The resource name of the Location + to list Tensorboards. Format: + 'projects/{project}/locations/{location}' + filter (str): + Lists the Tensorboards that match the filter + expression. + page_size (int): + The maximum number of Tensorboards to return. + The service may return fewer than this value. If + unspecified, at most 100 Tensorboards will be + returned. The maximum value is 100; values above + 100 will be coerced to 100. + page_token (str): + A page token, received from a previous + ``TensorboardService.ListTensorboards`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``TensorboardService.ListTensorboards`` + must match the call that provided the page token. + order_by (str): + Field to use to sort the list. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) + + +class ListTensorboardsResponse(proto.Message): + r"""Response message for + ``TensorboardService.ListTensorboards``. + + Attributes: + tensorboards (Sequence[google.cloud.aiplatform_v1beta1.types.Tensorboard]): + The Tensorboards mathching the request. + next_page_token (str): + A token, which can be sent as + ``ListTensorboardsRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + tensorboards = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_tensorboard.Tensorboard, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateTensorboardRequest(proto.Message): + r"""Request message for + ``TensorboardService.UpdateTensorboard``. + + Attributes: + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the Tensorboard resource by the update. The + fields specified in the update_mask are relative to the + resource, not the full request. A field will be overwritten + if it is in the mask. If the user does not provide a mask + then all fields will be overwritten if new values are + specified. + tensorboard (google.cloud.aiplatform_v1beta1.types.Tensorboard): + Required. The Tensorboard's ``name`` field is used to + identify the Tensorboard to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + """ + + update_mask = proto.Field(proto.MESSAGE, number=1, message=field_mask.FieldMask,) + + tensorboard = proto.Field( + proto.MESSAGE, number=2, message=gca_tensorboard.Tensorboard, + ) + + +class DeleteTensorboardRequest(proto.Message): + r"""Request message for + ``TensorboardService.DeleteTensorboard``. + + Attributes: + name (str): + Required. The name of the Tensorboard to be deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateTensorboardExperimentRequest(proto.Message): + r"""Request message for + ``TensorboardService.CreateTensorboardExperiment``. + + Attributes: + parent (str): + Required. The resource name of the Tensorboard to create the + TensorboardExperiment in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + tensorboard_experiment (google.cloud.aiplatform_v1beta1.types.TensorboardExperiment): + The TensorboardExperiment to create. + tensorboard_experiment_id (str): + Required. The ID to use for the Tensorboard experiment, + which will become the final component of the Tensorboard + experiment's resource name. + + This value should be 1-128 characters, and valid characters + are /[a-z][0-9]-/. + """ + + parent = proto.Field(proto.STRING, number=1) + + tensorboard_experiment = proto.Field( + proto.MESSAGE, + number=2, + message=gca_tensorboard_experiment.TensorboardExperiment, + ) + + tensorboard_experiment_id = proto.Field(proto.STRING, number=3) + + +class GetTensorboardExperimentRequest(proto.Message): + r"""Request message for + ``TensorboardService.GetTensorboardExperiment``. + + Attributes: + name (str): + Required. The name of the TensorboardExperiment resource. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListTensorboardExperimentsRequest(proto.Message): + r"""Request message for + ``TensorboardService.ListTensorboardExperiments``. + + Attributes: + parent (str): + Required. The resource name of the + Tensorboard to list TensorboardExperiments. + Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}' + filter (str): + Lists the TensorboardExperiments that match + the filter expression. + page_size (int): + The maximum number of TensorboardExperiments + to return. The service may return fewer than + this value. If unspecified, at most 50 + TensorboardExperiments will be returned. The + maximum value is 1000; values above 1000 will be + coerced to 1000. + page_token (str): + A page token, received from a previous + ``TensorboardService.ListTensorboardExperiments`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``TensorboardService.ListTensorboardExperiments`` + must match the call that provided the page token. + order_by (str): + Field to use to sort the list. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) + + +class ListTensorboardExperimentsResponse(proto.Message): + r"""Response message for + ``TensorboardService.ListTensorboardExperiments``. + + Attributes: + tensorboard_experiments (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardExperiment]): + The TensorboardExperiments mathching the + request. + next_page_token (str): + A token, which can be sent as + ``ListTensorboardExperimentsRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + tensorboard_experiments = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gca_tensorboard_experiment.TensorboardExperiment, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateTensorboardExperimentRequest(proto.Message): + r"""Request message for + ``TensorboardService.UpdateTensorboardExperiment``. + + Attributes: + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardExperiment resource by the + update. The fields specified in the update_mask are relative + to the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then all fields will be overwritten if new + values are specified. + tensorboard_experiment (google.cloud.aiplatform_v1beta1.types.TensorboardExperiment): + Required. The TensorboardExperiment's ``name`` field is used + to identify the TensorboardExperiment to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + """ + + update_mask = proto.Field(proto.MESSAGE, number=1, message=field_mask.FieldMask,) + + tensorboard_experiment = proto.Field( + proto.MESSAGE, + number=2, + message=gca_tensorboard_experiment.TensorboardExperiment, + ) + + +class DeleteTensorboardExperimentRequest(proto.Message): + r"""Request message for + ``TensorboardService.DeleteTensorboardExperiment``. + + Attributes: + name (str): + Required. The name of the TensorboardExperiment to be + deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateTensorboardRunRequest(proto.Message): + r"""Request message for + ``TensorboardService.CreateTensorboardRun``. + + Attributes: + parent (str): + Required. The resource name of the Tensorboard to create the + TensorboardRun in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}`` + tensorboard_run (google.cloud.aiplatform_v1beta1.types.TensorboardRun): + Required. The TensorboardRun to create. + tensorboard_run_id (str): + Required. The ID to use for the Tensorboard run, which will + become the final component of the Tensorboard run's resource + name. + + This value should be 1-128 characters, and valid characters + are /[a-z][0-9]-/. + """ + + parent = proto.Field(proto.STRING, number=1) + + tensorboard_run = proto.Field( + proto.MESSAGE, number=2, message=gca_tensorboard_run.TensorboardRun, + ) + + tensorboard_run_id = proto.Field(proto.STRING, number=3) + + +class GetTensorboardRunRequest(proto.Message): + r"""Request message for + ``TensorboardService.GetTensorboardRun``. + + Attributes: + name (str): + Required. The name of the TensorboardRun resource. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ReadTensorboardBlobDataRequest(proto.Message): + r"""Request message for + ``TensorboardService.ReadTensorboardBlobData``. + + Attributes: + time_series (str): + Required. The resource name of the TensorboardTimeSeries to + list Blobs. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}' + blob_ids (Sequence[str]): + IDs of the blobs to read. + """ + + time_series = proto.Field(proto.STRING, number=1) + + blob_ids = proto.RepeatedField(proto.STRING, number=2) + + +class ReadTensorboardBlobDataResponse(proto.Message): + r"""Response message for + ``TensorboardService.ReadTensorboardBlobData``. + + Attributes: + blobs (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardBlob]): + Blob messages containing blob bytes. + """ + + blobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=tensorboard_data.TensorboardBlob, + ) + + +class ListTensorboardRunsRequest(proto.Message): + r"""Request message for + ``TensorboardService.ListTensorboardRuns``. + + Attributes: + parent (str): + Required. The resource name of the + Tensorboard to list TensorboardRuns. Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}' + filter (str): + Lists the TensorboardRuns that match the + filter expression. + page_size (int): + The maximum number of TensorboardRuns to + return. The service may return fewer than this + value. If unspecified, at most 50 + TensorboardRuns will be returned. The maximum + value is 1000; values above 1000 will be coerced + to 1000. + page_token (str): + A page token, received from a previous + ``TensorboardService.ListTensorboardRuns`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``TensorboardService.ListTensorboardRuns`` + must match the call that provided the page token. + order_by (str): + Field to use to sort the list. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) + + +class ListTensorboardRunsResponse(proto.Message): + r"""Response message for + ``TensorboardService.ListTensorboardRuns``. + + Attributes: + tensorboard_runs (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardRun]): + The TensorboardRuns mathching the request. + next_page_token (str): + A token, which can be sent as + ``ListTensorboardRunsRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + tensorboard_runs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_tensorboard_run.TensorboardRun, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateTensorboardRunRequest(proto.Message): + r"""Request message for + ``TensorboardService.UpdateTensorboardRun``. + + Attributes: + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardRun resource by the update. + The fields specified in the update_mask are relative to the + resource, not the full request. A field will be overwritten + if it is in the mask. If the user does not provide a mask + then all fields will be overwritten if new values are + specified. + tensorboard_run (google.cloud.aiplatform_v1beta1.types.TensorboardRun): + Required. The TensorboardRun's ``name`` field is used to + identify the TensorboardRun to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + """ + + update_mask = proto.Field(proto.MESSAGE, number=1, message=field_mask.FieldMask,) + + tensorboard_run = proto.Field( + proto.MESSAGE, number=2, message=gca_tensorboard_run.TensorboardRun, + ) + + +class DeleteTensorboardRunRequest(proto.Message): + r"""Request message for + ``TensorboardService.DeleteTensorboardRun``. + + Attributes: + name (str): + Required. The name of the TensorboardRun to be deleted. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateTensorboardTimeSeriesRequest(proto.Message): + r"""Request message for + ``TensorboardService.CreateTensorboardTimeSeries``. + + Attributes: + parent (str): + Required. The resource name of the TensorboardRun to create + the TensorboardTimeSeries in. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + tensorboard_time_series_id (str): + Optional. The user specified unique ID to use for the + TensorboardTimeSeries, which will become the final component + of the TensorboardTimeSeries's resource name. Ref: + go/ucaip-user-specified-id + + This value should match "[a-z0-9][a-z0-9-]{0, 127}". + tensorboard_time_series (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries): + Required. The TensorboardTimeSeries to + create. + """ + + parent = proto.Field(proto.STRING, number=1) + + tensorboard_time_series_id = proto.Field(proto.STRING, number=3) + + tensorboard_time_series = proto.Field( + proto.MESSAGE, + number=2, + message=gca_tensorboard_time_series.TensorboardTimeSeries, + ) + + +class GetTensorboardTimeSeriesRequest(proto.Message): + r"""Request message for + ``TensorboardService.GetTensorboardTimeSeries``. + + Attributes: + name (str): + Required. The name of the TensorboardTimeSeries resource. + Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListTensorboardTimeSeriesRequest(proto.Message): + r"""Request message for + ``TensorboardService.ListTensorboardTimeSeries``. + + Attributes: + parent (str): + Required. The resource name of the + TensorboardRun to list TensorboardTimeSeries. + Format: + 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}' + filter (str): + Lists the TensorboardTimeSeries that match + the filter expression. + page_size (int): + The maximum number of TensorboardTimeSeries + to return. The service may return fewer than + this value. If unspecified, at most 50 + TensorboardTimeSeries will be returned. The + maximum value is 1000; values above 1000 will be + coerced to 1000. + page_token (str): + A page token, received from a previous + ``TensorboardService.ListTensorboardTimeSeries`` + call. Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + ``TensorboardService.ListTensorboardTimeSeries`` + must match the call that provided the page token. + order_by (str): + Field to use to sort the list. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + read_mask = proto.Field(proto.MESSAGE, number=6, message=field_mask.FieldMask,) + + +class ListTensorboardTimeSeriesResponse(proto.Message): + r"""Response message for + ``TensorboardService.ListTensorboardTimeSeries``. + + Attributes: + tensorboard_time_series (Sequence[google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries]): + The TensorboardTimeSeries mathching the + request. + next_page_token (str): + A token, which can be sent as + ``ListTensorboardTimeSeriesRequest.page_token`` + to retrieve the next page. If this field is omitted, there + are no subsequent pages. + """ + + @property + def raw_page(self): + return self + + tensorboard_time_series = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gca_tensorboard_time_series.TensorboardTimeSeries, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateTensorboardTimeSeriesRequest(proto.Message): + r"""Request message for + ``TensorboardService.UpdateTensorboardTimeSeries``. + + Attributes: + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Field mask is used to specify the fields to be + overwritten in the TensorboardTimeSeries resource by the + update. The fields specified in the update_mask are relative + to the resource, not the full request. A field will be + overwritten if it is in the mask. If the user does not + provide a mask then all fields will be overwritten if new + values are specified. + tensorboard_time_series (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries): + Required. The TensorboardTimeSeries' ``name`` field is used + to identify the TensorboardTimeSeries to be updated. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + """ + + update_mask = proto.Field(proto.MESSAGE, number=1, message=field_mask.FieldMask,) + + tensorboard_time_series = proto.Field( + proto.MESSAGE, + number=2, + message=gca_tensorboard_time_series.TensorboardTimeSeries, + ) + + +class DeleteTensorboardTimeSeriesRequest(proto.Message): + r"""Request message for + ``TensorboardService.DeleteTensorboardTimeSeries``. + + Attributes: + name (str): + Required. The name of the TensorboardTimeSeries to be + deleted. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ReadTensorboardTimeSeriesDataRequest(proto.Message): + r"""Request message for + ``TensorboardService.ReadTensorboardTimeSeriesData``. + + Attributes: + tensorboard_time_series (str): + Required. The resource name of the TensorboardTimeSeries to + read data from. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + max_data_points (int): + The maximum number of TensorboardTimeSeries' + data to return. + This value should be a positive integer. + This value can be set to -1 to return all data. + filter (str): + Reads the TensorboardTimeSeries' data that + match the filter expression. + """ + + tensorboard_time_series = proto.Field(proto.STRING, number=1) + + max_data_points = proto.Field(proto.INT32, number=2) + + filter = proto.Field(proto.STRING, number=3) + + +class ReadTensorboardTimeSeriesDataResponse(proto.Message): + r"""Response message for + ``TensorboardService.ReadTensorboardTimeSeriesData``. + + Attributes: + time_series_data (google.cloud.aiplatform_v1beta1.types.TimeSeriesData): + The returned time series data. + """ + + time_series_data = proto.Field( + proto.MESSAGE, number=1, message=tensorboard_data.TimeSeriesData, + ) + + +class WriteTensorboardRunDataRequest(proto.Message): + r"""Request message for + ``TensorboardService.WriteTensorboardRunData``. + + Attributes: + tensorboard_run (str): + Required. The resource name of the TensorboardRun to write + data to. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}`` + time_series_data (Sequence[google.cloud.aiplatform_v1beta1.types.TimeSeriesData]): + Required. The TensorboardTimeSeries data to + write. Values with in a time series are indexed + by their step value. Repeated writes to the same + step will overwrite the existing value for that + step. + The upper limit of data points per write request + is 5000. + """ + + tensorboard_run = proto.Field(proto.STRING, number=1) + + time_series_data = proto.RepeatedField( + proto.MESSAGE, number=2, message=tensorboard_data.TimeSeriesData, + ) + + +class WriteTensorboardRunDataResponse(proto.Message): + r"""Response message for + ``TensorboardService.WriteTensorboardRunData``. + """ + + +class ExportTensorboardTimeSeriesDataRequest(proto.Message): + r"""Request message for + ``TensorboardService.ExportTensorboardTimeSeriesData``. + + Attributes: + tensorboard_time_series (str): + Required. The resource name of the TensorboardTimeSeries to + export data from. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}`` + filter (str): + Exports the TensorboardTimeSeries' data that + match the filter expression. + page_size (int): + The maximum number of data points to return per page. The + default page_size will be 1000. Values must be between 1 and + 10000. Values above 10000 will be coerced to 10000. + page_token (str): + A page token, received from a previous + [TensorboardService.ExportTensorboardTimeSeries][] call. + Provide this to retrieve the subsequent page. + + When paginating, all other parameters provided to + [TensorboardService.ExportTensorboardTimeSeries][] must + match the call that provided the page token. + order_by (str): + Field to use to sort the + TensorboardTimeSeries' data. By default, + TensorboardTimeSeries' data will be returned in + a pseudo random order. + """ + + tensorboard_time_series = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + order_by = proto.Field(proto.STRING, number=5) + + +class ExportTensorboardTimeSeriesDataResponse(proto.Message): + r"""Response message for + ``TensorboardService.ExportTensorboardTimeSeriesData``. + + Attributes: + time_series_data_points (Sequence[google.cloud.aiplatform_v1beta1.types.TimeSeriesDataPoint]): + The returned time series data points. + next_page_token (str): + A token, which can be sent as + [ExportTensorboardTimeSeriesRequest.page_token][] to + retrieve the next page. If this field is omitted, there are + no subsequent pages. + """ + + @property + def raw_page(self): + return self + + time_series_data_points = proto.RepeatedField( + proto.MESSAGE, number=1, message=tensorboard_data.TimeSeriesDataPoint, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class CreateTensorboardOperationMetadata(proto.Message): + r"""Details of operations that perform create Tensorboard. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Tensorboard. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +class UpdateTensorboardOperationMetadata(proto.Message): + r"""Details of operations that perform update Tensorboard. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for Tensorboard. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py new file mode 100644 index 0000000000..47a66d38f6 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_time_series.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"TensorboardTimeSeries",}, +) + + +class TensorboardTimeSeries(proto.Message): + r"""TensorboardTimeSeries maps to times series produced in + training runs + + Attributes: + name (str): + Output only. Name of the + TensorboardTimeSeries. + display_name (str): + Required. User provided name of this + TensorboardTimeSeries. This value should be + unique among all TensorboardTimeSeries resources + belonging to the same TensorboardRun resource + (parent resource). + description (str): + Description of this TensorboardTimeSeries. + value_type (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries.ValueType): + Required. Immutable. Type of + TensorboardTimeSeries value. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + TensorboardTimeSeries was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + TensorboardTimeSeries was last updated. + etag (str): + Used to perform a consistent read-modify- + rite updates. If not set, a blind "overwrite" + update happens. + plugin_name (str): + Immutable. Name of the plugin this time + series pertain to. Such as Scalar, Tensor, Blob + plugin_data (bytes): + Data of the current plugin, with the size + limited to 65KB. + metadata (google.cloud.aiplatform_v1beta1.types.TensorboardTimeSeries.Metadata): + Output only. Scalar, Tensor, or Blob metadata + for this TensorboardTimeSeries. + """ + + class ValueType(proto.Enum): + r"""An enum representing the value type of a + TensorboardTimeSeries. + """ + VALUE_TYPE_UNSPECIFIED = 0 + SCALAR = 1 + TENSOR = 2 + BLOB_SEQUENCE = 3 + + class Metadata(proto.Message): + r"""Describes metadata for a TensorboardTimeSeries. + + Attributes: + max_step (int): + Output only. Max step index of all data + points within a TensorboardTimeSeries. + max_wall_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Max wall clock timestamp of all + data points within a TensorboardTimeSeries. + max_blob_sequence_length (int): + Output only. The largest blob sequence length (number of + blobs) of all data points in this time series, if its + ValueType is BLOB_SEQUENCE. + """ + + max_step = proto.Field(proto.INT64, number=1) + + max_wall_time = proto.Field( + proto.MESSAGE, number=2, message=timestamp.Timestamp, + ) + + max_blob_sequence_length = proto.Field(proto.INT64, number=3) + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + value_type = proto.Field(proto.ENUM, number=4, enum=ValueType,) + + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + etag = proto.Field(proto.STRING, number=7) + + plugin_name = proto.Field(proto.STRING, number=8) + + plugin_data = proto.Field(proto.BYTES, number=9) + + metadata = proto.Field(proto.MESSAGE, number=10, message=Metadata,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/types.py b/google/cloud/aiplatform_v1beta1/types/types.py index 127833d18e..53581d3bdb 100644 --- a/google/cloud/aiplatform_v1beta1/types/types.py +++ b/google/cloud/aiplatform_v1beta1/types/types.py @@ -25,7 +25,7 @@ class BoolArray(proto.Message): - r"""Bool list type feature value. + r"""A list of boolean values. Attributes: values (Sequence[bool]): @@ -36,7 +36,7 @@ class BoolArray(proto.Message): class DoubleArray(proto.Message): - r"""Double list type feature value. + r"""A list of double values. Attributes: values (Sequence[float]): @@ -47,7 +47,7 @@ class DoubleArray(proto.Message): class Int64Array(proto.Message): - r"""Int64 list type feature value. + r"""A list of int64 values. Attributes: values (Sequence[int]): diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index 9b5532b2e0..7c51035fbf 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -41,7 +41,7 @@ class UserActionReference(proto.Message): method (str): The method name of the API RPC call. For example, - "/google.cloud.aiplatform.master.DatasetService.CreateDataset". + "/google.cloud.aiplatform.{apiVersion}.DatasetService.CreateDataset". """ operation = proto.Field(proto.STRING, number=1, oneof="reference") diff --git a/google/cloud/aiplatform_v1beta1/types/value.py b/google/cloud/aiplatform_v1beta1/types/value.py new file mode 100644 index 0000000000..fe79c9e2e8 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/value.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"Value",}, +) + + +class Value(proto.Message): + r"""Value is the value of the field. + + Attributes: + int_value (int): + An integer value. + double_value (float): + A double value. + string_value (str): + A string value. + """ + + int_value = proto.Field(proto.INT64, number=1, oneof="value") + + double_value = proto.Field(proto.DOUBLE, number=2, oneof="value") + + string_value = proto.Field(proto.STRING, number=3, oneof="value") + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/noxfile.py b/noxfile.py index 0f27c1ff88..7b28a76f53 100644 --- a/noxfile.py +++ b/noxfile.py @@ -131,9 +131,6 @@ def system(session): # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": session.skip("RUN_SYSTEM_TESTS is set to false, skipping") - # Sanity check: Only run tests if the environment variable is set. - if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): - session.skip("Credentials must be set via environment variable") # Install pyopenssl for mTLS testing. if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": session.install("pyopenssl") diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 580c6a962d..112d5c200b 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -130,25 +130,83 @@ def mock_import_text_dataset(mock_text_dataset): yield mock -""" ----------------------------------------------------------------------------- -TrainingJob Fixtures ----------------------------------------------------------------------------- -""" +@pytest.fixture +def mock_import_video_data(mock_video_dataset): + with patch.object(mock_video_dataset, "import_data") as mock: + yield mock + + +# ---------------------------------------------------------------------------- +# TrainingJob Fixtures +# ---------------------------------------------------------------------------- @pytest.fixture -def mock_init_automl_image_training_job(): - with patch.object( - aiplatform.training_jobs.AutoMLImageTrainingJob, "__init__" - ) as mock: +def mock_custom_training_job(): + mock = MagicMock(aiplatform.training_jobs.CustomTrainingJob) + yield mock + + +@pytest.fixture +def mock_image_training_job(): + mock = MagicMock(aiplatform.training_jobs.AutoMLImageTrainingJob) + yield mock + + +@pytest.fixture +def mock_tabular_training_job(): + mock = MagicMock(aiplatform.training_jobs.AutoMLTabularTrainingJob) + yield mock + + +@pytest.fixture +def mock_text_training_job(): + mock = MagicMock(aiplatform.training_jobs.AutoMLTextTrainingJob) + yield mock + + +@pytest.fixture +def mock_video_training_job(): + mock = MagicMock(aiplatform.training_jobs.AutoMLVideoTrainingJob) + yield mock + + +@pytest.fixture +def mock_get_automl_tabular_training_job(mock_tabular_training_job): + with patch.object(aiplatform, "AutoMLTabularTrainingJob") as mock: + mock.return_value = mock_tabular_training_job + yield mock + + +@pytest.fixture +def mock_run_automl_tabular_training_job(mock_tabular_training_job): + with patch.object(mock_tabular_training_job, "run") as mock: + yield mock + + +@pytest.fixture +def mock_get_automl_image_training_job(mock_image_training_job): + with patch.object(aiplatform, "AutoMLImageTrainingJob") as mock: + mock.return_value = mock_image_training_job + yield mock + + +@pytest.fixture +def mock_run_automl_image_training_job(mock_image_training_job): + with patch.object(mock_image_training_job, "run") as mock: + yield mock + + +@pytest.fixture +def mock_init_custom_training_job(): + with patch.object(aiplatform.training_jobs.CustomTrainingJob, "__init__") as mock: mock.return_value = None yield mock @pytest.fixture -def mock_run_automl_image_training_job(): - with patch.object(aiplatform.training_jobs.AutoMLImageTrainingJob, "run") as mock: +def mock_run_custom_training_job(): + with patch.object(aiplatform.training_jobs.CustomTrainingJob, "run") as mock: yield mock @@ -160,15 +218,34 @@ def mock_run_automl_image_training_job(): @pytest.fixture -def mock_init_model(): - with patch.object(aiplatform.models.Model, "__init__") as mock: - mock.return_value = None +def mock_model(): + mock = MagicMock(aiplatform.models.Model) + yield mock + + +@pytest.fixture +def mock_init_model(mock_model): + with patch.object(aiplatform, "Model") as mock: + mock.return_value = mock_model + yield mock + + +@pytest.fixture +def mock_batch_predict_model(mock_model): + with patch.object(mock_model, "batch_predict") as mock: yield mock @pytest.fixture -def mock_batch_predict_model(): - with patch.object(aiplatform.models.Model, "batch_predict") as mock: +def mock_upload_model(): + with patch.object(aiplatform.models.Model, "upload") as mock: + yield mock + + +@pytest.fixture +def mock_deploy_model(mock_model, mock_endpoint): + with patch.object(mock_model, "deploy") as mock: + mock.return_value = mock_endpoint yield mock @@ -198,6 +275,12 @@ def mock_endpoint(): yield mock +@pytest.fixture +def mock_create_endpoint(): + with patch.object(aiplatform.Endpoint, "create") as mock: + yield mock + + @pytest.fixture def mock_get_endpoint(mock_endpoint): with patch.object(aiplatform, "Endpoint") as mock_get_endpoint: diff --git a/samples/model-builder/create_and_import_dataset_tabular_bigquery_sample.py b/samples/model-builder/create_and_import_dataset_tabular_bigquery_sample.py new file mode 100644 index 0000000000..b7f4ea8013 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_tabular_bigquery_sample.py @@ -0,0 +1,36 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample] +def create_and_import_dataset_tabular_bigquery_sample( + display_name: str, project: str, location: str, bigquery_source: str, +): + + aiplatform.init(project=project, location=location) + + dataset = aiplatform.TabularDataset.create( + display_name=display_name, bigquery_source=bigquery_source, + ) + + dataset.wait() + + print(f'\tDataset: "{dataset.display_name}"') + print(f'\tname: "{dataset.resource_name}"') + + +# [END aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample] diff --git a/samples/model-builder/create_and_import_dataset_tabular_bigquery_sample_test.py b/samples/model-builder/create_and_import_dataset_tabular_bigquery_sample_test.py new file mode 100644 index 0000000000..6eefcf7702 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_tabular_bigquery_sample_test.py @@ -0,0 +1,36 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_and_import_dataset_tabular_bigquery_sample +import test_constants as constants + + +def test_create_and_import_dataset_tabular_bigquery_sample( + mock_sdk_init, mock_create_tabular_dataset +): + + create_and_import_dataset_tabular_bigquery_sample.create_and_import_dataset_tabular_bigquery_sample( + project=constants.PROJECT, + location=constants.LOCATION, + bigquery_source=constants.BIGQUERY_SOURCE, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_tabular_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, bigquery_source=constants.BIGQUERY_SOURCE, + ) diff --git a/samples/model-builder/create_and_import_dataset_tabular_gcs_sample.py b/samples/model-builder/create_and_import_dataset_tabular_gcs_sample.py new file mode 100644 index 0000000000..cac7a64d89 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_tabular_gcs_sample.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_tabular_gcs_sample] +def create_and_import_dataset_tabular_gcs_sample( + display_name: str, project: str, location: str, gcs_source: Union[str, List[str]], +): + + aiplatform.init(project=project, location=location) + + dataset = aiplatform.TabularDataset.create( + display_name=display_name, gcs_source=gcs_source, + ) + + dataset.wait() + + print(f'\tDataset: "{dataset.display_name}"') + print(f'\tname: "{dataset.resource_name}"') + + +# [END aiplatform_sdk_create_and_import_dataset_tabular_gcs_sample] diff --git a/samples/model-builder/create_and_import_dataset_tabular_gcs_sample_test.py b/samples/model-builder/create_and_import_dataset_tabular_gcs_sample_test.py new file mode 100644 index 0000000000..ca8679be01 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_tabular_gcs_sample_test.py @@ -0,0 +1,36 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_and_import_dataset_tabular_gcs_sample +import test_constants as constants + + +def test_create_and_import_dataset_tabular_gcs_sample( + mock_sdk_init, mock_create_tabular_dataset +): + + create_and_import_dataset_tabular_gcs_sample.create_and_import_dataset_tabular_gcs_sample( + project=constants.PROJECT, + location=constants.LOCATION, + gcs_source=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_tabular_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, gcs_source=constants.GCS_SOURCES, + ) diff --git a/samples/model-builder/create_and_import_dataset_video_sample.py b/samples/model-builder/create_and_import_dataset_video_sample.py new file mode 100644 index 0000000000..60cfd65f7a --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_video_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_video_sample] +def create_and_import_dataset_video_sample( + project: str, + location: str, + display_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.VideoDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.video.classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_create_and_import_dataset_video_sample] diff --git a/samples/model-builder/create_and_import_dataset_video_sample_test.py b/samples/model-builder/create_and_import_dataset_video_sample_test.py new file mode 100644 index 0000000000..1ebbc7a3d0 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_video_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import create_and_import_dataset_video_sample + +import test_constants as constants + + +def test_create_and_import_dataset_video_sample( + mock_sdk_init, mock_create_video_dataset +): + + create_and_import_dataset_video_sample.create_and_import_dataset_video_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_video_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.video.classification, + sync=True, + ) diff --git a/samples/model-builder/create_batch_prediction_job_sample_test.py b/samples/model-builder/create_batch_prediction_job_sample_test.py index f39c1020b5..a3eb5ed085 100644 --- a/samples/model-builder/create_batch_prediction_job_sample_test.py +++ b/samples/model-builder/create_batch_prediction_job_sample_test.py @@ -18,7 +18,7 @@ def test_create_batch_prediction_job_sample( - mock_sdk_init, mock_init_model, mock_batch_predict_model + mock_sdk_init, mock_model, mock_init_model, mock_batch_predict_model ): create_batch_prediction_job_sample.create_batch_prediction_job_sample( diff --git a/samples/model-builder/create_endpoint_sample.py b/samples/model-builder/create_endpoint_sample.py new file mode 100644 index 0000000000..fa3762bd57 --- /dev/null +++ b/samples/model-builder/create_endpoint_sample.py @@ -0,0 +1,34 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_endpoint_sample] +def create_endpoint_sample( + project: str, display_name: str, location: str, sync: bool = True, +): + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint.create( + display_name=display_name, project=project, location=location, + ) + + print(endpoint.display_name) + print(endpoint.resource_name) + print(endpoint.uri) + return endpoint + + +# [END aiplatform_sdk_create_endpoint_sample] diff --git a/samples/model-builder/create_endpoint_sample_test.py b/samples/model-builder/create_endpoint_sample_test.py new file mode 100644 index 0000000000..af3631d93a --- /dev/null +++ b/samples/model-builder/create_endpoint_sample_test.py @@ -0,0 +1,36 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_endpoint_sample +import test_constants as constants + + +def test_create_endpoint_sample(mock_sdk_init, mock_create_endpoint): + + create_endpoint_sample.create_endpoint_sample( + project=constants.PROJECT, + display_name=constants.DISPLAY_NAME, + location=constants.LOCATION, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_create_endpoint.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + ) diff --git a/samples/model-builder/create_training_pipeline_custom_job_sample.py b/samples/model-builder/create_training_pipeline_custom_job_sample.py new file mode 100644 index 0000000000..cea8d25cde --- /dev/null +++ b/samples/model-builder/create_training_pipeline_custom_job_sample.py @@ -0,0 +1,69 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_custom_job_sample] +def create_training_pipeline_custom_job_sample( + project: str, + location: str, + display_name: str, + script_path: str, + container_uri: str, + model_serving_container_image_uri: str, + model_display_name: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + model_serving_container_image_uri=model_serving_container_image_uri, + ) + + model = job.run( + model_display_name=model_display_name, + args=args, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_custom_job_sample] diff --git a/samples/model-builder/create_training_pipeline_custom_job_test.py b/samples/model-builder/create_training_pipeline_custom_job_test.py new file mode 100644 index 0000000000..f01fd94a9a --- /dev/null +++ b/samples/model-builder/create_training_pipeline_custom_job_test.py @@ -0,0 +1,62 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_custom_job_sample +import test_constants as constants + + +def test_create_training_pipeline_custom_job_sample( + mock_sdk_init, mock_init_custom_training_job, mock_run_custom_training_job, +): + + create_training_pipeline_custom_job_sample.create_training_pipeline_custom_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + args=constants.ARGS, + script_path=constants.SCRIPT_PATH, + container_uri=constants.CONTAINER_URI, + model_serving_container_image_uri=constants.CONTAINER_URI, + model_display_name=constants.DISPLAY_NAME_2, + replica_count=constants.REPLICA_COUNT, + machine_type=constants.MACHINE_TYPE, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_custom_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + script_path=constants.SCRIPT_PATH, + container_uri=constants.CONTAINER_URI, + model_serving_container_image_uri=constants.CONTAINER_URI, + ) + mock_run_custom_training_job.assert_called_once_with( + model_display_name=constants.DISPLAY_NAME_2, + replica_count=constants.REPLICA_COUNT, + machine_type=constants.MACHINE_TYPE, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + args=constants.ARGS, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py b/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py new file mode 100644 index 0000000000..7d7dc6357c --- /dev/null +++ b/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py @@ -0,0 +1,73 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_custom_job_sample] +def create_training_pipeline_custom_training_managed_dataset_sample( + project: str, + location: str, + display_name: str, + script_path: str, + container_uri: str, + model_serving_container_image_uri: str, + dataset_id: int, + model_display_name: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + model_serving_container_image_uri=model_serving_container_image_uri, + ) + + my_image_ds = aiplatform.ImageDataset(dataset_id) + + model = job.run( + dataset=my_image_ds, + model_display_name=model_display_name, + args=args, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_custom_job_sample] diff --git a/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py b/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py new file mode 100644 index 0000000000..4197f658b1 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py @@ -0,0 +1,70 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_custom_training_managed_dataset_sample +import test_constants as constants + + +def test_create_training_pipeline_custom_job_sample( + mock_sdk_init, + mock_image_dataset, + mock_init_custom_training_job, + mock_run_custom_training_job, + mock_get_image_dataset, +): + + create_training_pipeline_custom_training_managed_dataset_sample.create_training_pipeline_custom_training_managed_dataset_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + args=constants.ARGS, + script_path=constants.SCRIPT_PATH, + container_uri=constants.CONTAINER_URI, + model_serving_container_image_uri=constants.CONTAINER_URI, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + replica_count=constants.REPLICA_COUNT, + machine_type=constants.MACHINE_TYPE, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + ) + + mock_get_image_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_custom_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + script_path=constants.SCRIPT_PATH, + container_uri=constants.CONTAINER_URI, + model_serving_container_image_uri=constants.CONTAINER_URI, + ) + mock_run_custom_training_job.assert_called_once_with( + dataset=mock_image_dataset, + model_display_name=constants.DISPLAY_NAME_2, + args=constants.ARGS, + replica_count=constants.REPLICA_COUNT, + machine_type=constants.MACHINE_TYPE, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample.py b/samples/model-builder/create_training_pipeline_image_classification_sample.py index 050d40af82..3786894a05 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from google.cloud import aiplatform # [START aiplatform_sdk_create_training_pipeline_image_classification_sample] def create_training_pipeline_image_classification_sample( project: str, + location: str, display_name: str, dataset_id: int, - location: str = "us-central1", - model_display_name: str = None, + model_display_name: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py index c49e0e5f05..1c7080e7a1 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py @@ -20,13 +20,14 @@ def test_create_training_pipeline_image_classification_sample( mock_sdk_init, mock_image_dataset, - mock_init_automl_image_training_job, + mock_get_automl_image_training_job, mock_run_automl_image_training_job, mock_get_image_dataset, ): create_training_pipeline_image_classification_sample.create_training_pipeline_image_classification_sample( project=constants.PROJECT, + location=constants.LOCATION, display_name=constants.DISPLAY_NAME, dataset_id=constants.RESOURCE_ID, model_display_name=constants.DISPLAY_NAME_2, @@ -42,7 +43,7 @@ def test_create_training_pipeline_image_classification_sample( mock_sdk_init.assert_called_once_with( project=constants.PROJECT, location=constants.LOCATION ) - mock_init_automl_image_training_job.assert_called_once_with( + mock_get_automl_image_training_job.assert_called_once_with( display_name=constants.DISPLAY_NAME ) mock_run_automl_image_training_job.assert_called_once_with( diff --git a/samples/model-builder/create_training_pipeline_tabular_classification_sample.py b/samples/model-builder/create_training_pipeline_tabular_classification_sample.py new file mode 100644 index 0000000000..6bd9405383 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_tabular_classification_sample.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_tabular_classification_sample] +def create_training_pipeline_tabular_classification_sample( + project: str, + display_name: str, + dataset_id: int, + location: str = "us-central1", + model_display_name: str = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + tabular_classification_job = aiplatform.AutoMLTabularTrainingJob( + display_name=display_name, + ) + + my_tabular_dataset = aiplatform.TabularDataset(dataset_id) + + model = tabular_classification_job.run( + dataset=my_tabular_dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + model_display_name=model_display_name, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_tabular_classification_sample] diff --git a/samples/model-builder/create_training_pipeline_tabular_classification_sample_test.py b/samples/model-builder/create_training_pipeline_tabular_classification_sample_test.py new file mode 100644 index 0000000000..c015e99785 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_tabular_classification_sample_test.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_tabular_classification_sample +import test_constants as constants + + +def test_create_training_pipeline_tabular_classification_sample( + mock_sdk_init, + mock_tabular_dataset, + mock_get_automl_tabular_training_job, + mock_run_automl_tabular_training_job, + mock_get_tabular_dataset, +): + + create_training_pipeline_tabular_classification_sample.create_training_pipeline_tabular_classification_sample( + project=constants.PROJECT, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_tabular_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_tabular_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + ) + mock_run_automl_tabular_training_job.assert_called_once_with( + dataset=mock_tabular_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_tabular_regression_sample.py b/samples/model-builder/create_training_pipeline_tabular_regression_sample.py new file mode 100644 index 0000000000..2404bb37e2 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_tabular_regression_sample.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_tabular_regression_sample] +def create_training_pipeline_tabular_regression_sample( + project: str, + display_name: str, + dataset_id: int, + location: str = "us-central1", + model_display_name: str = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + tabular_regression_job = aiplatform.AutoMLTabularTrainingJob( + display_name=display_name, + ) + + my_tabular_dataset = aiplatform.TabularDataset(dataset_id) + + model = tabular_regression_job.run( + dataset=my_tabular_dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + model_display_name=model_display_name, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_tabular_regression_sample] diff --git a/samples/model-builder/create_training_pipeline_tabular_regression_sample_test.py b/samples/model-builder/create_training_pipeline_tabular_regression_sample_test.py new file mode 100644 index 0000000000..1e897b5851 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_tabular_regression_sample_test.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_tabular_regression_sample +import test_constants as constants + + +def test_create_training_pipeline_tabular_regression_sample( + mock_sdk_init, + mock_tabular_dataset, + mock_get_automl_tabular_training_job, + mock_run_automl_tabular_training_job, + mock_get_tabular_dataset, +): + + create_training_pipeline_tabular_regression_sample.create_training_pipeline_tabular_regression_sample( + project=constants.PROJECT, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_tabular_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_tabular_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + ) + mock_run_automl_tabular_training_job.assert_called_once_with( + dataset=mock_tabular_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/deploy_model_with_automatic_resources_sample.py b/samples/model-builder/deploy_model_with_automatic_resources_sample.py new file mode 100644 index 0000000000..27976ae652 --- /dev/null +++ b/samples/model-builder/deploy_model_with_automatic_resources_sample.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence, Tuple + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_deploy_model_with_automatic_resources_sample] +def deploy_model_with_automatic_resources_sample( + project, + location, + model_name: str, + endpoint: Optional[aiplatform.Endpoint] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + min_replica_count: int = 1, + max_replica_count: int = 1, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync: bool = True, +): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model(model_name=model_name) + + model.deploy( + endpoint=endpoint, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + metadata=metadata, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_deploy_model_with_automatic_resources_sample] diff --git a/samples/model-builder/deploy_model_with_automatic_resources_test.py b/samples/model-builder/deploy_model_with_automatic_resources_test.py new file mode 100644 index 0000000000..fff08b6e7e --- /dev/null +++ b/samples/model-builder/deploy_model_with_automatic_resources_test.py @@ -0,0 +1,52 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import deploy_model_with_automatic_resources_sample +import test_constants as constants + + +def test_deploy_model_with_automatic_resources_sample( + mock_sdk_init, mock_model, mock_init_model, mock_deploy_model, +): + + deploy_model_with_automatic_resources_sample.deploy_model_with_automatic_resources_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_name=constants.MODEL_NAME, + endpoint=constants.ENDPOINT_NAME, + deployed_model_display_name=constants.DEPLOYED_MODEL_DISPLAY_NAME, + traffic_percentage=constants.TRAFFIC_PERCENTAGE, + traffic_split=constants.TRAFFIC_SPLIT, + min_replica_count=constants.MIN_REPLICA_COUNT, + max_replica_count=constants.MAX_REPLICA_COUNT, + metadata=constants.ENDPOINT_DEPLOY_METADATA, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_init_model.assert_called_once_with(model_name=constants.MODEL_NAME) + + mock_deploy_model.assert_called_once_with( + endpoint=constants.ENDPOINT_NAME, + deployed_model_display_name=constants.DEPLOYED_MODEL_DISPLAY_NAME, + traffic_percentage=constants.TRAFFIC_PERCENTAGE, + traffic_split=constants.TRAFFIC_SPLIT, + min_replica_count=constants.MIN_REPLICA_COUNT, + max_replica_count=constants.MAX_REPLICA_COUNT, + metadata=constants.ENDPOINT_DEPLOY_METADATA, + sync=True, + ) diff --git a/samples/model-builder/deploy_model_with_dedicated_resources_sample.py b/samples/model-builder/deploy_model_with_dedicated_resources_sample.py new file mode 100644 index 0000000000..093dfae805 --- /dev/null +++ b/samples/model-builder/deploy_model_with_dedicated_resources_sample.py @@ -0,0 +1,70 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence, Tuple + +from google.cloud import aiplatform +from google.cloud.aiplatform import explain + + +# [START aiplatform_sdk_deploy_model_with_dedicated_resources_sample] +def deploy_model_with_dedicated_resources_sample( + project, + location, + model_name: str, + machine_type: str, + endpoint: Optional[aiplatform.Endpoint] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + min_replica_count: int = 1, + max_replica_count: int = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync: bool = True, +): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model(model_name=model_name) + + # The explanation_metadata and explanation_parameters should only be + # provided for a custom trained model and not an AutoML model. + model.deploy( + endpoint=endpoint, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_deploy_model_with_dedicated_resources_sample] diff --git a/samples/model-builder/deploy_model_with_dedicated_resources_test.py b/samples/model-builder/deploy_model_with_dedicated_resources_test.py new file mode 100644 index 0000000000..6dac9ad6e3 --- /dev/null +++ b/samples/model-builder/deploy_model_with_dedicated_resources_test.py @@ -0,0 +1,62 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import deploy_model_with_dedicated_resources_sample +import test_constants as constants + + +def test_deploy_model_with_dedicated_resources_sample( + mock_sdk_init, mock_model, mock_init_model, mock_deploy_model +): + + deploy_model_with_dedicated_resources_sample.deploy_model_with_dedicated_resources_sample( + project=constants.PROJECT, + location=constants.LOCATION, + machine_type=constants.MACHINE_TYPE, + model_name=constants.MODEL_NAME, + endpoint=constants.ENDPOINT_NAME, + deployed_model_display_name=constants.DEPLOYED_MODEL_DISPLAY_NAME, + traffic_percentage=constants.TRAFFIC_PERCENTAGE, + traffic_split=constants.TRAFFIC_SPLIT, + min_replica_count=constants.MIN_REPLICA_COUNT, + max_replica_count=constants.MAX_REPLICA_COUNT, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + metadata=constants.ENDPOINT_DEPLOY_METADATA, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_init_model.assert_called_once_with(model_name=constants.MODEL_NAME) + + mock_deploy_model.assert_called_once_with( + endpoint=constants.ENDPOINT_NAME, + deployed_model_display_name=constants.DEPLOYED_MODEL_DISPLAY_NAME, + traffic_percentage=constants.TRAFFIC_PERCENTAGE, + traffic_split=constants.TRAFFIC_SPLIT, + machine_type=constants.MACHINE_TYPE, + min_replica_count=constants.MIN_REPLICA_COUNT, + max_replica_count=constants.MAX_REPLICA_COUNT, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + metadata=constants.ENDPOINT_DEPLOY_METADATA, + sync=True, + ) diff --git a/samples/model-builder/get_model_sample.py b/samples/model-builder/get_model_sample.py new file mode 100644 index 0000000000..e5dff928ac --- /dev/null +++ b/samples/model-builder/get_model_sample.py @@ -0,0 +1,31 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_get_model_sample] +def get_model_sample(project: str, location: str, model_name: str): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model(model_name=model_name) + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_get_model_sample] diff --git a/samples/model-builder/get_model_test.py b/samples/model-builder/get_model_test.py new file mode 100644 index 0000000000..4bb5f5fddb --- /dev/null +++ b/samples/model-builder/get_model_test.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import get_model_sample +import test_constants as constants + + +def test_get_model_sample(mock_sdk_init, mock_init_model): + + get_model_sample.get_model_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_name=constants.MODEL_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_init_model.assert_called_once_with(model_name=constants.MODEL_NAME) diff --git a/samples/model-builder/import_data_text_entity_extraction_sample_test.py b/samples/model-builder/import_data_text_entity_extraction_sample_test.py index a3b93e9200..44ce9cc328 100644 --- a/samples/model-builder/import_data_text_entity_extraction_sample_test.py +++ b/samples/model-builder/import_data_text_entity_extraction_sample_test.py @@ -34,9 +34,7 @@ def test_import_data_text_entity_extraction_sample( project=constants.PROJECT, location=constants.LOCATION ) - mock_get_text_dataset.assert_called_once_with( - constants.DATASET_NAME, - ) + mock_get_text_dataset.assert_called_once_with(constants.DATASET_NAME,) mock_import_text_dataset.assert_called_once_with( gcs_source=constants.GCS_SOURCES, diff --git a/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py b/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py index 2134d66b35..8bfd6ac0c3 100644 --- a/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py +++ b/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py @@ -34,9 +34,7 @@ def test_import_data_text_sentiment_analysis_sample( project=constants.PROJECT, location=constants.LOCATION ) - mock_get_text_dataset.assert_called_once_with( - constants.DATASET_NAME, - ) + mock_get_text_dataset.assert_called_once_with(constants.DATASET_NAME,) mock_import_text_dataset.assert_called_once_with( gcs_source=constants.GCS_SOURCES, diff --git a/samples/model-builder/import_data_video_action_recognition_sample.py b/samples/model-builder/import_data_video_action_recognition_sample.py new file mode 100644 index 0000000000..dbdab59215 --- /dev/null +++ b/samples/model-builder/import_data_video_action_recognition_sample.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_video_action_recognition_sample] +def import_data_video_action_recognition_sample( + project: str, + location: str, + dataset_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.VideoDataset(dataset_name=dataset_name) + + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.video.action_recognition, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_video_action_recognition_sample] diff --git a/samples/model-builder/import_data_video_action_recognition_sample_test.py b/samples/model-builder/import_data_video_action_recognition_sample_test.py new file mode 100644 index 0000000000..f8eef996f2 --- /dev/null +++ b/samples/model-builder/import_data_video_action_recognition_sample_test.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.aiplatform import schema + +import pytest + +import import_data_video_action_recognition_sample + +import test_constants as constants + + +@pytest.mark.usefixtures("mock_get_video_dataset") +def test_import_data_video_action_recognition_sample( + mock_sdk_init, mock_import_video_data +): + + import_data_video_action_recognition_sample.import_data_video_action_recognition_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset_name=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_import_video_data.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.video.action_recognition, + sync=True, + ) diff --git a/samples/model-builder/import_data_video_classification_sample.py b/samples/model-builder/import_data_video_classification_sample.py new file mode 100644 index 0000000000..1e7243746c --- /dev/null +++ b/samples/model-builder/import_data_video_classification_sample.py @@ -0,0 +1,46 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_video_classification_sample] +def import_data_video_classification_sample( + project: str, + location: str, + dataset_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.VideoDataset(dataset_name=dataset_name) + + print(ds.display_name) + print(ds.resource_name) + + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.video.classification, + sync=sync, + ) + + ds.wait() + + return ds + + +# [END aiplatform_sdk_import_data_video_classification_sample] diff --git a/samples/model-builder/import_data_video_classification_sample_test.py b/samples/model-builder/import_data_video_classification_sample_test.py new file mode 100644 index 0000000000..cce5c0abd6 --- /dev/null +++ b/samples/model-builder/import_data_video_classification_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.aiplatform import schema + +import pytest + +import import_data_video_classification_sample + +import test_constants as constants + + +@pytest.mark.usefixtures("mock_get_video_dataset") +def test_import_data_video_classification_sample(mock_sdk_init, mock_import_video_data): + + import_data_video_classification_sample.import_data_video_classification_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset_name=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_import_video_data.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.video.classification, + sync=True, + ) diff --git a/samples/model-builder/import_data_video_object_tracking_sample.py b/samples/model-builder/import_data_video_object_tracking_sample.py new file mode 100644 index 0000000000..ae38748b89 --- /dev/null +++ b/samples/model-builder/import_data_video_object_tracking_sample.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_video_object_tracking_sample] +def import_data_video_object_tracking_sample( + project: str, + location: str, + dataset_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.VideoDataset(dataset_name=dataset_name) + + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.video.object_tracking, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_video_object_tracking_sample] diff --git a/samples/model-builder/import_data_video_object_tracking_sample_test.py b/samples/model-builder/import_data_video_object_tracking_sample_test.py new file mode 100644 index 0000000000..2f94bc7dac --- /dev/null +++ b/samples/model-builder/import_data_video_object_tracking_sample_test.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.aiplatform import schema + +import pytest + +import import_data_video_object_tracking_sample + +import test_constants as constants + + +@pytest.mark.usefixtures("mock_get_video_dataset") +def test_import_data_video_object_tracking_sample( + mock_sdk_init, mock_import_video_data +): + + import_data_video_object_tracking_sample.import_data_video_object_tracking_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset_name=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_import_video_data.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.video.object_tracking, + sync=True, + ) diff --git a/samples/model-builder/predict_tabular_classification_sample.py b/samples/model-builder/predict_tabular_classification_sample.py new file mode 100644 index 0000000000..e5b1a0283d --- /dev/null +++ b/samples/model-builder/predict_tabular_classification_sample.py @@ -0,0 +1,35 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, List + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_predict_tabular_classification_sample] +def predict_tabular_classification_sample( + project: str, location: str, endpoint: str, instances: List[Dict], +): + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint) + + response = endpoint.predict(instances=instances) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_predict_tabular_classification_sample] diff --git a/samples/model-builder/predict_tabular_classification_sample_test.py b/samples/model-builder/predict_tabular_classification_sample_test.py new file mode 100644 index 0000000000..49a701115b --- /dev/null +++ b/samples/model-builder/predict_tabular_classification_sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import predict_tabular_classification_sample +import test_constants as constants + + +def test_predict_tabular_classification_sample(mock_sdk_init, mock_get_endpoint): + + predict_tabular_classification_sample.predict_tabular_classification_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint=constants.ENDPOINT_NAME, + instances=constants.PREDICTION_TABULAR_CLASSIFICATION_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) diff --git a/samples/model-builder/predict_tabular_regression_sample.py b/samples/model-builder/predict_tabular_regression_sample.py new file mode 100644 index 0000000000..fee4d34e38 --- /dev/null +++ b/samples/model-builder/predict_tabular_regression_sample.py @@ -0,0 +1,34 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_predict_tabular_regression_sample] +def predict_tabular_regression_sample( + project: str, location: str, endpoint: str, instances: List[Dict], +): + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint) + + response = endpoint.predict(instances=instances) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_predict_tabular_regression_sample] diff --git a/samples/model-builder/predict_tabular_regression_sample_test.py b/samples/model-builder/predict_tabular_regression_sample_test.py new file mode 100644 index 0000000000..7491d7c1d5 --- /dev/null +++ b/samples/model-builder/predict_tabular_regression_sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import predict_tabular_regression_sample +import test_constants as constants + + +def test_predict_tabular_regression_sample(mock_sdk_init, mock_get_endpoint): + + predict_tabular_regression_sample.predict_tabular_regression_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint=constants.ENDPOINT_NAME, + instances=constants.PREDICTION_TABULAR_REGRESSOIN_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) diff --git a/samples/model-builder/predict_text_classification_single_label_sample_test.py b/samples/model-builder/predict_text_classification_single_label_sample_test.py index c446235a79..789f2962c3 100644 --- a/samples/model-builder/predict_text_classification_single_label_sample_test.py +++ b/samples/model-builder/predict_text_classification_single_label_sample_test.py @@ -32,6 +32,4 @@ def test_predict_text_classification_single_label_sample( project=constants.PROJECT, location=constants.LOCATION ) - mock_get_endpoint.assert_called_once_with( - constants.ENDPOINT_NAME, - ) + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) diff --git a/samples/model-builder/predict_text_entity_extraction_sample_test.py b/samples/model-builder/predict_text_entity_extraction_sample_test.py index 3ca2b49b43..3b123ff148 100644 --- a/samples/model-builder/predict_text_entity_extraction_sample_test.py +++ b/samples/model-builder/predict_text_entity_extraction_sample_test.py @@ -30,6 +30,4 @@ def test_predict_text_entity_extraction_sample(mock_sdk_init, mock_get_endpoint) project=constants.PROJECT, location=constants.LOCATION ) - mock_get_endpoint.assert_called_once_with( - constants.ENDPOINT_NAME, - ) + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) diff --git a/samples/model-builder/predict_text_sentiment_analysis_sample_test.py b/samples/model-builder/predict_text_sentiment_analysis_sample_test.py index c2ed180c9f..e3a3fad58c 100644 --- a/samples/model-builder/predict_text_sentiment_analysis_sample_test.py +++ b/samples/model-builder/predict_text_sentiment_analysis_sample_test.py @@ -30,6 +30,4 @@ def test_predict_text_sentiment_analysis_sample(mock_sdk_init, mock_get_endpoint project=constants.PROJECT, location=constants.LOCATION ) - mock_get_endpoint.assert_called_once_with( - constants.ENDPOINT_NAME, - ) + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 50dfa968b4..994a8724ee 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -16,6 +16,7 @@ from uuid import uuid4 from google.auth import credentials +from google.cloud import aiplatform PROJECT = "abc" LOCATION = "us-central1" @@ -39,7 +40,10 @@ MODEL_NAME = f"{PARENT}/models/{RESOURCE_ID}" TRAINING_JOB_NAME = f"{PARENT}/trainingJobs/{RESOURCE_ID}" +BIGQUERY_SOURCE = f"bq://{PROJECT}.{DATASET_NAME}.table1" + GCS_SOURCES = ["gs://bucket1/source1.jsonl", "gs://bucket7/source4.jsonl"] +BIGQUERY_SOURCE = "bq://bigquery-public-data.ml_datasets.iris" GCS_DESTINATION = "gs://bucket3/output-dir/" TRAINING_FRACTION_SPLIT = 0.7 @@ -51,3 +55,97 @@ ENCRYPTION_SPEC_KEY_NAME = f"{PARENT}/keyRings/{RESOURCE_ID}/cryptoKeys/{RESOURCE_ID_2}" PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output" + +PREDICTION_TABULAR_CLASSIFICATION_INSTANCE = [ + { + "petal_length": "1.4", + "petal_width": "1.3", + "sepal_length": "5.1", + "sepal_width": "2.8", + } +] +PREDICTION_TABULAR_REGRESSOIN_INSTANCE = [ + { + "BOOLEAN_2unique_NULLABLE": False, + "DATETIME_1unique_NULLABLE": "2019-01-01 00:00:00", + "DATE_1unique_NULLABLE": "2019-01-01", + "FLOAT_5000unique_NULLABLE": 1611, + "FLOAT_5000unique_REPEATED": [2320, 1192], + "INTEGER_5000unique_NULLABLE": "8", + "NUMERIC_5000unique_NULLABLE": 16, + "STRING_5000unique_NULLABLE": "str-2", + "STRUCT_NULLABLE": { + "BOOLEAN_2unique_NULLABLE": False, + "DATE_1unique_NULLABLE": "2019-01-01", + "DATETIME_1unique_NULLABLE": "2019-01-01 00:00:00", + "FLOAT_5000unique_NULLABLE": 1308, + "FLOAT_5000unique_REPEATED": [2323, 1178], + "FLOAT_5000unique_REQUIRED": 3089, + "INTEGER_5000unique_NULLABLE": "1777", + "NUMERIC_5000unique_NULLABLE": 3323, + "TIME_1unique_NULLABLE": "23:59:59.999999", + "STRING_5000unique_NULLABLE": "str-49", + "TIMESTAMP_1unique_NULLABLE": "1546387199999999", + }, + "TIMESTAMP_1unique_NULLABLE": "1546387199999999", + "TIME_1unique_NULLABLE": "23:59:59.999999", + } +] +SCRIPT_PATH = "task.py" +CONTAINER_URI = "gcr.io/my_project/my_image:latest" +ARGS = ["--tfds", "tf_flowers:3.*.*"] +REPLICA_COUNT = 0 +MACHINE_TYPE = "n1-standard-4" +ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" +ACCELERATOR_COUNT = 0 + +# Model constants +MODEL_RESOURCE_NAME = f"{PARENT}/models/1234" +MODEL_ARTIFACT_URI = "gs://bucket3/output-dir/" +SERVING_CONTAINER_IMAGE_URI = "http://gcr.io/test/test:latest" +SERVING_CONTAINER_IMAGE = "gcr.io/test-serving/container:image" +SERVING_CONTAINER_PREDICT_ROUTE = "predict" +SERVING_CONTAINER_HEALTH_ROUTE = "metadata" +DESCRIPTION = "test description" +SERVING_CONTAINER_COMMAND = ["python3", "run_my_model.py"] +SERVING_CONTAINER_ARGS = ["--test", "arg"] +SERVING_CONTAINER_ENVIRONMENT_VARIABLES = { + "learning_rate": 0.01, + "loss_fn": "mse", +} +SERVING_CONTAINER_PORTS = [8888, 10000] +INSTANCE_SCHEMA_URI = "gs://test/schema/instance.yaml" +PARAMETERS_SCHEMA_URI = "gs://test/schema/parameters.yaml" +PREDICTION_SCHEMA_URI = "gs://test/schema/predictions.yaml" + +MODEL_DESCRIPTION = "This is a model" +SERVING_CONTAINER_COMMAND = ["python3", "run_my_model.py"] +SERVING_CONTAINER_ARGS = ["--test", "arg"] +SERVING_CONTAINER_ENVIRONMENT_VARIABLES = { + "learning_rate": 0.01, + "loss_fn": "mse", +} + +SERVING_CONTAINER_PORTS = [8888, 10000] +EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( + inputs={ + "features": { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + }, + outputs={"medv": {"output_tensor_name": "dense_2"}}, +) +EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +# Endpoint constants +DEPLOYED_MODEL_DISPLAY_NAME = "model_name" +TRAFFIC_PERCENTAGE = 80 +TRAFFIC_SPLIT = {"a": 99, "b": 1} +MIN_REPLICA_COUNT = 1 +MAX_REPLICA_COUNT = 1 +ENDPOINT_DEPLOY_METADATA = () diff --git a/samples/model-builder/upload_model_sample.py b/samples/model-builder/upload_model_sample.py new file mode 100644 index 0000000000..05cb910b12 --- /dev/null +++ b/samples/model-builder/upload_model_sample.py @@ -0,0 +1,71 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence + +from google.cloud import aiplatform +from google.cloud.aiplatform import explain + + +# [START aiplatform_sdk_upload_model_sample] +def upload_model_sample( + project: str, + location: str, + display_name: str, + serving_container_image_uri: str, + artifact_uri: Optional[str] = None, + serving_container_predict_route: Optional[str] = None, + serving_container_health_route: Optional[str] = None, + description: Optional[str] = None, + serving_container_command: Optional[Sequence[str]] = None, + serving_container_args: Optional[Sequence[str]] = None, + serving_container_environment_variables: Optional[Dict[str, str]] = None, + serving_container_ports: Optional[Sequence[int]] = None, + instance_schema_uri: Optional[str] = None, + parameters_schema_uri: Optional[str] = None, + prediction_schema_uri: Optional[str] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + sync: bool = True, +): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model.upload( + display_name=display_name, + artifact_uri=artifact_uri, + serving_container_image_uri=serving_container_image_uri, + serving_container_predict_route=serving_container_predict_route, + serving_container_health_route=serving_container_health_route, + instance_schema_uri=instance_schema_uri, + parameters_schema_uri=parameters_schema_uri, + prediction_schema_uri=prediction_schema_uri, + description=description, + serving_container_command=serving_container_command, + serving_container_args=serving_container_args, + serving_container_environment_variables=serving_container_environment_variables, + serving_container_ports=serving_container_ports, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_upload_model_sample] diff --git a/samples/model-builder/upload_model_test.py b/samples/model-builder/upload_model_test.py new file mode 100644 index 0000000000..ea00051c1a --- /dev/null +++ b/samples/model-builder/upload_model_test.py @@ -0,0 +1,62 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import test_constants as constants +import upload_model_sample + + +def test_upload_model_sample(mock_sdk_init, mock_upload_model): + + upload_model_sample.upload_model_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.MODEL_NAME, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + description=constants.MODEL_DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_upload_model.assert_called_once_with( + display_name=constants.MODEL_NAME, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + description=constants.MODEL_DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + sync=True, + ) diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index 8a6680087b..8c99f2e635 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,3 @@ -pytest==6.2.2 +pytest==6.2.3 google-cloud-storage>=1.26.0, <2.0.0dev google-cloud-aiplatform==0.7.1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index ea74c89e5e..03c3f38667 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -85,6 +85,7 @@ _TEST_PREDICTION = [[1.0, 2.0, 3.0], [3.0, 3.0, 1.0]] _TEST_INSTANCES = [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]] _TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) +_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com" _TEST_DEPLOYED_MODELS = [ gca_endpoint.DeployedModel(id=_TEST_ID, display_name=_TEST_DISPLAY_NAME), @@ -667,6 +668,7 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, sync=sync, ) @@ -687,6 +689,7 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): dedicated_resources=expected_dedicated_resources, model=test_model.resource_name, display_name=None, + service_account=_TEST_SERVICE_ACCOUNT, ) deploy_model_mock.assert_called_once_with( endpoint=test_endpoint.resource_name, diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 84498d0a37..fc0c33dbca 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -144,6 +144,11 @@ def test_create_client_user_agent(self): ("us-central1", None, "us-central1-aiplatform.googleapis.com"), ("us-central1", "europe-west4", "europe-west4-aiplatform.googleapis.com",), ("asia-east1", None, "asia-east1-aiplatform.googleapis.com"), + ( + "asia-southeast1", + "australia-southeast1", + "australia-southeast1-aiplatform.googleapis.com", + ), ], ) def test_get_client_options( diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 47b000d189..ad84fde65b 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -120,6 +120,7 @@ _TEST_PREDICTION_SCHEMA_URI = "gs://test/schema/predictions.yaml" _TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) +_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com" _TEST_EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( inputs={ @@ -156,6 +157,39 @@ ) _TEST_OUTPUT_DIR = "gs://my-output-bucket" +_TEST_CONTAINER_REGISTRY_DESTINATION = ( + "us-central1-docker.pkg.dev/projectId/repoName/imageName" +) + +_TEST_EXPORT_FORMAT_ID_IMAGE = "custom-trained" +_TEST_EXPORT_FORMAT_ID_ARTIFACT = "tf-saved-model" + +_TEST_SUPPORTED_EXPORT_FORMATS_IMAGE = [ + gca_model.Model.ExportFormat( + id=_TEST_EXPORT_FORMAT_ID_IMAGE, + exportable_contents=[gca_model.Model.ExportFormat.ExportableContent.IMAGE], + ) +] + +_TEST_SUPPORTED_EXPORT_FORMATS_ARTIFACT = [ + gca_model.Model.ExportFormat( + id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + exportable_contents=[gca_model.Model.ExportFormat.ExportableContent.ARTIFACT], + ) +] + +_TEST_SUPPORTED_EXPORT_FORMATS_BOTH = [ + gca_model.Model.ExportFormat( + id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + exportable_contents=[ + gca_model.Model.ExportFormat.ExportableContent.ARTIFACT, + gca_model.Model.ExportFormat.ExportableContent.IMAGE, + ], + ) +] + +_TEST_SUPPORTED_EXPORT_FORMATS_UNSUPPORTED = [] +_TEST_CONTAINER_REGISTRY_DESTINATION @pytest.fixture @@ -218,6 +252,58 @@ def get_model_with_custom_project_mock(): yield get_model_mock +@pytest.fixture +def get_model_with_supported_export_formats_image(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME, + supported_export_formats=_TEST_SUPPORTED_EXPORT_FORMATS_IMAGE, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_supported_export_formats_artifact(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME, + supported_export_formats=_TEST_SUPPORTED_EXPORT_FORMATS_ARTIFACT, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_both_supported_export_formats(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME, + supported_export_formats=_TEST_SUPPORTED_EXPORT_FORMATS_BOTH, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_unsupported_export_formats(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME, + supported_export_formats=_TEST_SUPPORTED_EXPORT_FORMATS_UNSUPPORTED, + ) + yield get_model_mock + + @pytest.fixture def upload_model_mock(): with mock.patch.object( @@ -270,6 +356,22 @@ def upload_model_with_custom_location_mock(): yield upload_model_mock +@pytest.fixture +def export_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "export_model" + ) as export_model_mock: + export_model_lro_mock = mock.Mock(ga_operation.Operation) + export_model_lro_mock.metadata = gca_model_service.ExportModelOperationMetadata( + output_info=gca_model_service.ExportModelOperationMetadata.OutputInfo( + artifact_output_uri=_TEST_OUTPUT_DIR + ) + ) + export_model_lro_mock.result.return_value = None + export_model_mock.return_value = export_model_lro_mock + yield export_model_mock + + @pytest.fixture def delete_model_mock(): with mock.patch.object( @@ -405,7 +507,6 @@ def test_constructor_create_client_with_custom_location(self, create_client_mock def test_constructor_creates_client_with_custom_credentials( self, create_client_mock ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) creds = auth_credentials.AnonymousCredentials() models.Model(_TEST_ID, credentials=creds) create_client_mock.assert_called_once_with( @@ -416,12 +517,10 @@ def test_constructor_creates_client_with_custom_credentials( ) def test_constructor_gets_model(self, get_model_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Model(_TEST_ID) get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) def test_constructor_gets_model_with_custom_project(self, get_model_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Model(_TEST_ID, project=_TEST_PROJECT_2) test_model_resource_name = model_service_client.ModelServiceClient.model_path( _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID @@ -429,7 +528,6 @@ def test_constructor_gets_model_with_custom_project(self, get_model_mock): get_model_mock.assert_called_once_with(name=test_model_resource_name) def test_constructor_gets_model_with_custom_location(self, get_model_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) models.Model(_TEST_ID, location=_TEST_LOCATION_2) test_model_resource_name = model_service_client.ModelServiceClient.model_path( _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID @@ -441,7 +539,6 @@ def test_upload_uploads_and_gets_model( self, upload_model_mock, get_model_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) my_model = models.Model.upload( display_name=_TEST_MODEL_NAME, serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, @@ -488,8 +585,6 @@ def test_upload_uploads_and_gets_model_with_all_args( self, upload_model_with_explanations_mock, get_model_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - my_model = models.Model.upload( display_name=_TEST_MODEL_NAME, artifact_uri=_TEST_ARTIFACT_URI, @@ -563,8 +658,6 @@ def test_upload_uploads_and_gets_model_with_custom_project( sync, ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - test_model_resource_name = model_service_client.ModelServiceClient.model_path( _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID ) @@ -611,7 +704,6 @@ def test_upload_uploads_and_gets_model_with_custom_location( get_model_with_custom_location_mock, sync, ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model_resource_name = model_service_client.ModelServiceClient.model_path( _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID ) @@ -715,6 +807,7 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, sync=sync, ) @@ -733,6 +826,7 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): dedicated_resources=expected_dedicated_resources, model=test_model.resource_name, display_name=None, + service_account=_TEST_SERVICE_ACCOUNT, ) deploy_model_mock.assert_called_once_with( endpoint=test_endpoint.resource_name, @@ -748,7 +842,6 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): def test_deploy_no_endpoint_with_explanations( self, deploy_model_with_explanations_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) test_endpoint = test_model.deploy( machine_type=_TEST_MACHINE_TYPE, @@ -940,7 +1033,6 @@ def test_batch_predict_gcs_source_bq_dest( def test_batch_predict_with_all_args( self, create_batch_prediction_job_with_explanations_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) creds = auth_credentials.AnonymousCredentials() @@ -1099,7 +1191,6 @@ def test_delete_model(self, delete_model_mock, sync): @pytest.mark.usefixtures("get_model_mock") def test_print_model(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) assert ( repr(test_model) @@ -1108,7 +1199,6 @@ def test_print_model(self): @pytest.mark.usefixtures("get_model_mock") def test_print_model_if_waiting(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) test_model._gca_resource = None test_model._latest_future = futures.Future() @@ -1119,7 +1209,6 @@ def test_print_model_if_waiting(self): @pytest.mark.usefixtures("get_model_mock") def test_print_model_if_exception(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) test_model = models.Model(_TEST_ID) test_model._gca_resource = None mock_exception = Exception("mock exception") @@ -1128,3 +1217,170 @@ def test_print_model_if_exception(self): repr(test_model) == f"{object.__repr__(test_model)} failed with {str(mock_exception)}" ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_with_supported_export_formats_artifact") + def test_export_model_as_artifact(self, export_model_mock, sync): + test_model = models.Model(_TEST_ID) + + if not sync: + test_model.wait() + + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + artifact_destination=_TEST_OUTPUT_DIR, + ) + + expected_output_config = gca_model_service.ExportModelRequest.OutputConfig( + export_format_id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + artifact_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_OUTPUT_DIR + ), + ) + + export_model_mock.assert_called_once_with( + name=f"{_TEST_PARENT}/models/{_TEST_ID}", + output_config=expected_output_config, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_with_supported_export_formats_image") + def test_export_model_as_image(self, export_model_mock, sync): + test_model = models.Model(_TEST_ID) + + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_IMAGE, + image_destination=_TEST_CONTAINER_REGISTRY_DESTINATION, + ) + + if not sync: + test_model.wait() + + expected_output_config = gca_model_service.ExportModelRequest.OutputConfig( + export_format_id=_TEST_EXPORT_FORMAT_ID_IMAGE, + image_destination=gca_io.ContainerRegistryDestination( + output_uri=_TEST_CONTAINER_REGISTRY_DESTINATION + ), + ) + + export_model_mock.assert_called_once_with( + name=f"{_TEST_PARENT}/models/{_TEST_ID}", + output_config=expected_output_config, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_with_both_supported_export_formats") + def test_export_model_as_both_formats(self, export_model_mock, sync): + """Exports a 'tf-saved-model' as both an artifact and an image""" + + test_model = models.Model(_TEST_ID) + + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + image_destination=_TEST_CONTAINER_REGISTRY_DESTINATION, + artifact_destination=_TEST_OUTPUT_DIR, + ) + + if not sync: + test_model.wait() + + expected_output_config = gca_model_service.ExportModelRequest.OutputConfig( + export_format_id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + image_destination=gca_io.ContainerRegistryDestination( + output_uri=_TEST_CONTAINER_REGISTRY_DESTINATION + ), + artifact_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_OUTPUT_DIR + ), + ) + + export_model_mock.assert_called_once_with( + name=f"{_TEST_PARENT}/models/{_TEST_ID}", + output_config=expected_output_config, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_with_unsupported_export_formats") + def test_export_model_not_supported(self, export_model_mock, sync): + test_model = models.Model(_TEST_ID) + + with pytest.raises(ValueError) as e: + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_IMAGE, + image_destination=_TEST_CONTAINER_REGISTRY_DESTINATION, + ) + + if not sync: + test_model.wait() + + assert e.match( + regexp=f"The model `{_TEST_PARENT}/models/{_TEST_ID}` is not exportable." + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_with_supported_export_formats_image") + def test_export_model_as_image_with_invalid_args(self, export_model_mock, sync): + + # Passing an artifact destination on an image-only Model + with pytest.raises(ValueError) as dest_type_err: + test_model = models.Model(_TEST_ID) + + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_IMAGE, + artifact_destination=_TEST_OUTPUT_DIR, + sync=sync, + ) + + if not sync: + test_model.wait() + + # Passing no destination type + with pytest.raises(ValueError) as no_dest_err: + test_model = models.Model(_TEST_ID) + + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_IMAGE, sync=sync, + ) + + if not sync: + test_model.wait() + + # Passing an invalid export format ID + with pytest.raises(ValueError) as format_err: + test_model = models.Model(_TEST_ID) + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + image_destination=_TEST_CONTAINER_REGISTRY_DESTINATION, + sync=sync, + ) + + if not sync: + test_model.wait() + + assert dest_type_err.match( + regexp=r"This model can not be exported as an artifact." + ) + assert no_dest_err.match(regexp=r"Please provide an") + assert format_err.match( + regexp=f"'{_TEST_EXPORT_FORMAT_ID_ARTIFACT}' is not a supported export format" + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_with_supported_export_formats_artifact") + def test_export_model_as_artifact_with_invalid_args(self, export_model_mock, sync): + test_model = models.Model(_TEST_ID) + + # Passing an image destination on an artifact-only Model + with pytest.raises(ValueError) as e: + test_model.export_model( + export_format_id=_TEST_EXPORT_FORMAT_ID_ARTIFACT, + image_destination=_TEST_CONTAINER_REGISTRY_DESTINATION, + sync=sync, + ) + + if not sync: + test_model.wait() + + assert e.match( + regexp=r"This model can not be exported as a container image." + ) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index b5520a5f4c..1a61469444 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -84,6 +84,7 @@ _TEST_ANNOTATION_SCHEMA_URI = schema.dataset.annotation.image.classification _TEST_BASE_OUTPUT_DIR = "gs://test-base-output-dir" +_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com" _TEST_BIGQUERY_DESTINATION = "bq://test-project" _TEST_RUN_ARGS = ["-v", 0.1, "--test=arg"] _TEST_REPLICA_COUNT = 1 @@ -593,6 +594,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( model_from_job = job.run( dataset=mock_tabular_dataset, base_output_dir=_TEST_BASE_OUTPUT_DIR, + service_account=_TEST_SERVICE_ACCOUNT, args=_TEST_RUN_ARGS, replica_count=1, machine_type=_TEST_MACHINE_TYPE, @@ -688,6 +690,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( { "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + "serviceAccount": _TEST_SERVICE_ACCOUNT, }, struct_pb2.Value(), ), @@ -2501,6 +2504,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( dataset=mock_nontabular_dataset, annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, base_output_dir=_TEST_BASE_OUTPUT_DIR, + service_account=_TEST_SERVICE_ACCOUNT, args=_TEST_RUN_ARGS, replica_count=1, machine_type=_TEST_MACHINE_TYPE, @@ -2582,6 +2586,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( { "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + "serviceAccount": _TEST_SERVICE_ACCOUNT, }, struct_pb2.Value(), ), @@ -2930,6 +2935,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( dataset=mock_tabular_dataset, model_display_name=_TEST_MODEL_DISPLAY_NAME, base_output_dir=_TEST_BASE_OUTPUT_DIR, + service_account=_TEST_SERVICE_ACCOUNT, args=_TEST_RUN_ARGS, replica_count=1, machine_type=_TEST_MACHINE_TYPE, @@ -3011,6 +3017,157 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( ), ) + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + "serviceAccount": _TEST_SERVICE_ACCOUNT, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + # model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME + "-model", + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + true_training_pipeline = gca_training_pipeline.TrainingPipeline( display_name=_TEST_DISPLAY_NAME, training_task_definition=schema.training_job.definition.custom_task, diff --git a/tests/unit/gapic/__init__.py b/tests/unit/gapic/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 76634da6a7..d1b0b51231 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -1549,19 +1549,21 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" + location = "mussel" + dataset = "winkle" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1571,21 +1573,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" + project = "squid" + dataset = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", + "project": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py index e3f19d0271..db9a7d5367 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -44,7 +44,6 @@ from google.cloud.aiplatform_v1beta1.types import feature_selector from google.cloud.aiplatform_v1beta1.types import featurestore_online_service from google.oauth2 import service_account -from google.protobuf import timestamp_pb2 as timestamp # type: ignore def client_cert_source_callback(): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py index f5e67013c6..cffb5d0ade 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -5125,6 +5125,232 @@ async def test_batch_read_feature_values_flattened_error_async(): ) +def test_export_feature_values( + transport: str = "grpc", + request_type=featurestore_service.ExportFeatureValuesRequest, +): + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ExportFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_feature_values_from_dict(): + test_export_feature_values(request_type=dict) + + +def test_export_feature_values_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + client.export_feature_values() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ExportFeatureValuesRequest() + + +@pytest.mark.asyncio +async def test_export_feature_values_async( + transport: str = "grpc_asyncio", + request_type=featurestore_service.ExportFeatureValuesRequest, +): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == featurestore_service.ExportFeatureValuesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_export_feature_values_async_from_dict(): + await test_export_feature_values_async(request_type=dict) + + +def test_export_feature_values_field_headers(): + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ExportFeatureValuesRequest() + request.entity_type = "entity_type/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_export_feature_values_field_headers_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = featurestore_service.ExportFeatureValuesRequest() + request.entity_type = "entity_type/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "entity_type=entity_type/value",) in kw["metadata"] + + +def test_export_feature_values_flattened(): + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.export_feature_values(entity_type="entity_type_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == "entity_type_value" + + +def test_export_feature_values_flattened_error(): + client = FeaturestoreServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.export_feature_values( + featurestore_service.ExportFeatureValuesRequest(), + entity_type="entity_type_value", + ) + + +@pytest.mark.asyncio +async def test_export_feature_values_flattened_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_feature_values), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.export_feature_values(entity_type="entity_type_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].entity_type == "entity_type_value" + + +@pytest.mark.asyncio +async def test_export_feature_values_flattened_error_async(): + client = FeaturestoreServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.export_feature_values( + featurestore_service.ExportFeatureValuesRequest(), + entity_type="entity_type_value", + ) + + def test_search_features( transport: str = "grpc", request_type=featurestore_service.SearchFeaturesRequest ): @@ -5592,6 +5818,7 @@ def test_featurestore_service_base_transport(): "delete_feature", "import_feature_values", "batch_read_feature_values", + "export_feature_values", "search_features", ) for method in methods: diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index 5e8e860b32..9580632c24 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -2782,35 +2782,8 @@ def test_parse_index_endpoint_path(): assert expected == actual -def test_index_endpoint_path(): - project = "squid" - location = "clam" - index_endpoint = "whelk" - - expected = "projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}".format( - project=project, location=location, index_endpoint=index_endpoint, - ) - actual = IndexEndpointServiceClient.index_endpoint_path( - project, location, index_endpoint - ) - assert expected == actual - - -def test_parse_index_endpoint_path(): - expected = { - "project": "octopus", - "location": "oyster", - "index_endpoint": "nudibranch", - } - path = IndexEndpointServiceClient.index_endpoint_path(**expected) - - # Check that the path construction is reversible. - actual = IndexEndpointServiceClient.parse_index_endpoint_path(path) - assert expected == actual - - def test_common_billing_account_path(): - billing_account = "cuttlefish" + billing_account = "squid" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, @@ -2821,7 +2794,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "clam", } path = IndexEndpointServiceClient.common_billing_account_path(**expected) @@ -2831,7 +2804,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "winkle" + folder = "whelk" expected = "folders/{folder}".format(folder=folder,) actual = IndexEndpointServiceClient.common_folder_path(folder) @@ -2840,7 +2813,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "octopus", } path = IndexEndpointServiceClient.common_folder_path(**expected) @@ -2850,7 +2823,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "scallop" + organization = "oyster" expected = "organizations/{organization}".format(organization=organization,) actual = IndexEndpointServiceClient.common_organization_path(organization) @@ -2859,7 +2832,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "nudibranch", } path = IndexEndpointServiceClient.common_organization_path(**expected) @@ -2869,7 +2842,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "squid" + project = "cuttlefish" expected = "projects/{project}".format(project=project,) actual = IndexEndpointServiceClient.common_project_path(project) @@ -2878,7 +2851,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "mussel", } path = IndexEndpointServiceClient.common_project_path(**expected) @@ -2888,8 +2861,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "whelk" - location = "octopus" + project = "winkle" + location = "nautilus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -2900,8 +2873,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "scallop", + "location": "abalone", } path = IndexEndpointServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index c8e506d54b..6acb3e7b86 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -8846,11 +8846,59 @@ def test_parse_model_deployment_monitoring_job_path(): assert expected == actual -def test_trial_path(): +def test_network_path(): project = "squid" - location = "clam" - study = "whelk" - trial = "octopus" + network = "clam" + + expected = "projects/{project}/global/networks/{network}".format( + project=project, network=network, + ) + actual = JobServiceClient.network_path(project, network) + assert expected == actual + + +def test_parse_network_path(): + expected = { + "project": "whelk", + "network": "octopus", + } + path = JobServiceClient.network_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_network_path(path) + assert expected == actual + + +def test_tensorboard_path(): + project = "oyster" + location = "nudibranch" + tensorboard = "cuttlefish" + + expected = "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( + project=project, location=location, tensorboard=tensorboard, + ) + actual = JobServiceClient.tensorboard_path(project, location, tensorboard) + assert expected == actual + + +def test_parse_tensorboard_path(): + expected = { + "project": "mussel", + "location": "winkle", + "tensorboard": "nautilus", + } + path = JobServiceClient.tensorboard_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_tensorboard_path(path) + assert expected == actual + + +def test_trial_path(): + project = "scallop" + location = "abalone" + study = "squid" + trial = "clam" expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( project=project, location=location, study=study, trial=trial, @@ -8861,10 +8909,10 @@ def test_trial_path(): def test_parse_trial_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", + "project": "whelk", + "location": "octopus", + "study": "oyster", + "trial": "nudibranch", } path = JobServiceClient.trial_path(**expected) @@ -8874,7 +8922,7 @@ def test_parse_trial_path(): def test_common_billing_account_path(): - billing_account = "winkle" + billing_account = "cuttlefish" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, @@ -8885,7 +8933,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "mussel", } path = JobServiceClient.common_billing_account_path(**expected) @@ -8895,7 +8943,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "scallop" + folder = "winkle" expected = "folders/{folder}".format(folder=folder,) actual = JobServiceClient.common_folder_path(folder) @@ -8904,7 +8952,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "nautilus", } path = JobServiceClient.common_folder_path(**expected) @@ -8914,7 +8962,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "squid" + organization = "scallop" expected = "organizations/{organization}".format(organization=organization,) actual = JobServiceClient.common_organization_path(organization) @@ -8923,7 +8971,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "abalone", } path = JobServiceClient.common_organization_path(**expected) @@ -8933,7 +8981,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "whelk" + project = "squid" expected = "projects/{project}".format(project=project,) actual = JobServiceClient.common_project_path(project) @@ -8942,7 +8990,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "clam", } path = JobServiceClient.common_project_path(**expected) @@ -8952,8 +9000,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "oyster" - location = "nudibranch" + project = "whelk" + location = "octopus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -8964,8 +9012,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "oyster", + "location": "nudibranch", } path = JobServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index b9c944280d..45fd76e099 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -736,7 +736,9 @@ def test_get_metadata_store( type(client.transport.get_metadata_store), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_store.MetadataStore(name="name_value",) + call.return_value = metadata_store.MetadataStore( + name="name_value", description="description_value", + ) response = client.get_metadata_store(request) @@ -752,6 +754,8 @@ def test_get_metadata_store( assert response.name == "name_value" + assert response.description == "description_value" + def test_get_metadata_store_from_dict(): test_get_metadata_store(request_type=dict) @@ -794,7 +798,9 @@ async def test_get_metadata_store_async( ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_store.MetadataStore(name="name_value",) + metadata_store.MetadataStore( + name="name_value", description="description_value", + ) ) response = await client.get_metadata_store(request) @@ -810,6 +816,8 @@ async def test_get_metadata_store_async( assert response.name == "name_value" + assert response.description == "description_value" + @pytest.mark.asyncio async def test_get_metadata_store_async_from_dict(): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index f547beb6bf..221f22f654 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -1551,21 +1551,19 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - location = "mussel" - dataset = "winkle" + dataset = "mussel" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1575,19 +1573,21 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "squid" - dataset = "clam" + project = "scallop" + location = "abalone" + dataset = "squid" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", + "project": "clam", + "location": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index e353077d80..59218c0ed9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -43,20 +43,26 @@ ) from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports +from google.cloud.aiplatform_v1beta1.types import artifact +from google.cloud.aiplatform_v1beta1.types import context from google.cloud.aiplatform_v1beta1.types import deployed_model_ref from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import env_var +from google.cloud.aiplatform_v1beta1.types import execution from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import explanation_metadata from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_job +from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline from google.cloud.aiplatform_v1beta1.types import ( training_pipeline as gca_training_pipeline, ) +from google.cloud.aiplatform_v1beta1.types import value from google.longrunning import operations_pb2 from google.oauth2 import service_account from google.protobuf import any_pb2 as gp_any # type: ignore @@ -1808,6 +1814,1321 @@ async def test_cancel_training_pipeline_flattened_error_async(): ) +def test_create_pipeline_job( + transport: str = "grpc", request_type=pipeline_service.CreatePipelineJobRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_pipeline_job.PipelineJob( + name="name_value", + display_name="display_name_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + service_account="service_account_value", + network="network_value", + ) + + response = client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreatePipelineJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_pipeline_job.PipelineJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + assert response.service_account == "service_account_value" + + assert response.network == "network_value" + + +def test_create_pipeline_job_from_dict(): + test_create_pipeline_job(request_type=dict) + + +def test_create_pipeline_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + client.create_pipeline_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreatePipelineJobRequest() + + +@pytest.mark.asyncio +async def test_create_pipeline_job_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CreatePipelineJobRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_pipeline_job.PipelineJob( + name="name_value", + display_name="display_name_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + service_account="service_account_value", + network="network_value", + ) + ) + + response = await client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreatePipelineJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_pipeline_job.PipelineJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + assert response.service_account == "service_account_value" + + assert response.network == "network_value" + + +@pytest.mark.asyncio +async def test_create_pipeline_job_async_from_dict(): + await test_create_pipeline_job_async(request_type=dict) + + +def test_create_pipeline_job_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreatePipelineJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + call.return_value = gca_pipeline_job.PipelineJob() + + client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_pipeline_job_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreatePipelineJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_pipeline_job.PipelineJob() + ) + + await client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_pipeline_job_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_pipeline_job.PipelineJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_pipeline_job( + parent="parent_value", + pipeline_job=gca_pipeline_job.PipelineJob(name="name_value"), + pipeline_job_id="pipeline_job_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].pipeline_job == gca_pipeline_job.PipelineJob(name="name_value") + + assert args[0].pipeline_job_id == "pipeline_job_id_value" + + +def test_create_pipeline_job_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_pipeline_job( + pipeline_service.CreatePipelineJobRequest(), + parent="parent_value", + pipeline_job=gca_pipeline_job.PipelineJob(name="name_value"), + pipeline_job_id="pipeline_job_id_value", + ) + + +@pytest.mark.asyncio +async def test_create_pipeline_job_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_pipeline_job.PipelineJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_pipeline_job.PipelineJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_pipeline_job( + parent="parent_value", + pipeline_job=gca_pipeline_job.PipelineJob(name="name_value"), + pipeline_job_id="pipeline_job_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].pipeline_job == gca_pipeline_job.PipelineJob(name="name_value") + + assert args[0].pipeline_job_id == "pipeline_job_id_value" + + +@pytest.mark.asyncio +async def test_create_pipeline_job_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_pipeline_job( + pipeline_service.CreatePipelineJobRequest(), + parent="parent_value", + pipeline_job=gca_pipeline_job.PipelineJob(name="name_value"), + pipeline_job_id="pipeline_job_id_value", + ) + + +def test_get_pipeline_job( + transport: str = "grpc", request_type=pipeline_service.GetPipelineJobRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_job.PipelineJob( + name="name_value", + display_name="display_name_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + service_account="service_account_value", + network="network_value", + ) + + response = client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetPipelineJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pipeline_job.PipelineJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + assert response.service_account == "service_account_value" + + assert response.network == "network_value" + + +def test_get_pipeline_job_from_dict(): + test_get_pipeline_job(request_type=dict) + + +def test_get_pipeline_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + client.get_pipeline_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetPipelineJobRequest() + + +@pytest.mark.asyncio +async def test_get_pipeline_job_async( + transport: str = "grpc_asyncio", request_type=pipeline_service.GetPipelineJobRequest +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_job.PipelineJob( + name="name_value", + display_name="display_name_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + service_account="service_account_value", + network="network_value", + ) + ) + + response = await client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetPipelineJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pipeline_job.PipelineJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + assert response.service_account == "service_account_value" + + assert response.network == "network_value" + + +@pytest.mark.asyncio +async def test_get_pipeline_job_async_from_dict(): + await test_get_pipeline_job_async(request_type=dict) + + +def test_get_pipeline_job_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetPipelineJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + call.return_value = pipeline_job.PipelineJob() + + client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_pipeline_job_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetPipelineJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_job.PipelineJob() + ) + + await client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_pipeline_job_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_job.PipelineJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_pipeline_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_pipeline_job_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_pipeline_job( + pipeline_service.GetPipelineJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_pipeline_job_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_job.PipelineJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_job.PipelineJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_pipeline_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_pipeline_job_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_pipeline_job( + pipeline_service.GetPipelineJobRequest(), name="name_value", + ) + + +def test_list_pipeline_jobs( + transport: str = "grpc", request_type=pipeline_service.ListPipelineJobsRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListPipelineJobsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListPipelineJobsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListPipelineJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_pipeline_jobs_from_dict(): + test_list_pipeline_jobs(request_type=dict) + + +def test_list_pipeline_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + client.list_pipeline_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListPipelineJobsRequest() + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.ListPipelineJobsRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListPipelineJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListPipelineJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListPipelineJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_async_from_dict(): + await test_list_pipeline_jobs_async(request_type=dict) + + +def test_list_pipeline_jobs_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.ListPipelineJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + call.return_value = pipeline_service.ListPipelineJobsResponse() + + client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.ListPipelineJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListPipelineJobsResponse() + ) + + await client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_pipeline_jobs_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListPipelineJobsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_pipeline_jobs(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_pipeline_jobs_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_pipeline_jobs( + pipeline_service.ListPipelineJobsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListPipelineJobsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListPipelineJobsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_pipeline_jobs(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_pipeline_jobs( + pipeline_service.ListPipelineJobsRequest(), parent="parent_value", + ) + + +def test_list_pipeline_jobs_pager(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[ + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + ], + next_page_token="abc", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[], next_page_token="def", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(),], next_page_token="ghi", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(), pipeline_job.PipelineJob(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_pipeline_jobs(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, pipeline_job.PipelineJob) for i in results) + + +def test_list_pipeline_jobs_pages(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[ + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + ], + next_page_token="abc", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[], next_page_token="def", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(),], next_page_token="ghi", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(), pipeline_job.PipelineJob(),], + ), + RuntimeError, + ) + pages = list(client.list_pipeline_jobs(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_async_pager(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[ + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + ], + next_page_token="abc", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[], next_page_token="def", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(),], next_page_token="ghi", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(), pipeline_job.PipelineJob(),], + ), + RuntimeError, + ) + async_pager = await client.list_pipeline_jobs(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, pipeline_job.PipelineJob) for i in responses) + + +@pytest.mark.asyncio +async def test_list_pipeline_jobs_async_pages(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_pipeline_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[ + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + pipeline_job.PipelineJob(), + ], + next_page_token="abc", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[], next_page_token="def", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(),], next_page_token="ghi", + ), + pipeline_service.ListPipelineJobsResponse( + pipeline_jobs=[pipeline_job.PipelineJob(), pipeline_job.PipelineJob(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_pipeline_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_pipeline_job( + transport: str = "grpc", request_type=pipeline_service.DeletePipelineJobRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeletePipelineJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_pipeline_job_from_dict(): + test_delete_pipeline_job(request_type=dict) + + +def test_delete_pipeline_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + client.delete_pipeline_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeletePipelineJobRequest() + + +@pytest.mark.asyncio +async def test_delete_pipeline_job_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.DeletePipelineJobRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeletePipelineJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_pipeline_job_async_from_dict(): + await test_delete_pipeline_job_async(request_type=dict) + + +def test_delete_pipeline_job_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeletePipelineJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_pipeline_job_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeletePipelineJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_pipeline_job_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_pipeline_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_pipeline_job_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_pipeline_job( + pipeline_service.DeletePipelineJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_pipeline_job_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_pipeline_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_pipeline_job_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_pipeline_job( + pipeline_service.DeletePipelineJobRequest(), name="name_value", + ) + + +def test_cancel_pipeline_job( + transport: str = "grpc", request_type=pipeline_service.CancelPipelineJobRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelPipelineJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_pipeline_job_from_dict(): + test_cancel_pipeline_job(request_type=dict) + + +def test_cancel_pipeline_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + client.cancel_pipeline_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelPipelineJobRequest() + + +@pytest.mark.asyncio +async def test_cancel_pipeline_job_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CancelPipelineJobRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelPipelineJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_pipeline_job_async_from_dict(): + await test_cancel_pipeline_job_async(request_type=dict) + + +def test_cancel_pipeline_job_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelPipelineJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_pipeline_job_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelPipelineJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_pipeline_job_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_pipeline_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_pipeline_job_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_pipeline_job( + pipeline_service.CancelPipelineJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_pipeline_job_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_pipeline_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_pipeline_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_pipeline_job_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_pipeline_job( + pipeline_service.CancelPipelineJobRequest(), name="name_value", + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.PipelineServiceGrpcTransport( @@ -1910,6 +3231,11 @@ def test_pipeline_service_base_transport(): "list_training_pipelines", "delete_training_pipeline", "cancel_training_pipeline", + "create_pipeline_job", + "get_pipeline_job", + "list_pipeline_jobs", + "delete_pipeline_job", + "cancel_pipeline_job", ) for method in methods: with pytest.raises(NotImplementedError): @@ -2191,10 +3517,99 @@ def test_pipeline_service_grpc_lro_async_client(): assert transport.operations_client is transport.operations_client -def test_endpoint_path(): +def test_artifact_path(): project = "squid" location = "clam" - endpoint = "whelk" + metadata_store = "whelk" + artifact = "octopus" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) + actual = PipelineServiceClient.artifact_path( + project, location, metadata_store, artifact + ) + assert expected == actual + + +def test_parse_artifact_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "artifact": "mussel", + } + path = PipelineServiceClient.artifact_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_artifact_path(path) + assert expected == actual + + +def test_context_path(): + project = "winkle" + location = "nautilus" + metadata_store = "scallop" + context = "abalone" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + actual = PipelineServiceClient.context_path( + project, location, metadata_store, context + ) + assert expected == actual + + +def test_parse_context_path(): + expected = { + "project": "squid", + "location": "clam", + "metadata_store": "whelk", + "context": "octopus", + } + path = PipelineServiceClient.context_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_context_path(path) + assert expected == actual + + +def test_custom_job_path(): + project = "oyster" + location = "nudibranch" + custom_job = "cuttlefish" + + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + actual = PipelineServiceClient.custom_job_path(project, location, custom_job) + assert expected == actual + + +def test_parse_custom_job_path(): + expected = { + "project": "mussel", + "location": "winkle", + "custom_job": "nautilus", + } + path = PipelineServiceClient.custom_job_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_custom_job_path(path) + assert expected == actual + + +def test_endpoint_path(): + project = "scallop" + location = "abalone" + endpoint = "squid" expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( project=project, location=location, endpoint=endpoint, @@ -2205,9 +3620,9 @@ def test_endpoint_path(): def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "clam", + "location": "whelk", + "endpoint": "octopus", } path = PipelineServiceClient.endpoint_path(**expected) @@ -2216,10 +3631,42 @@ def test_parse_endpoint_path(): assert expected == actual +def test_execution_path(): + project = "oyster" + location = "nudibranch" + metadata_store = "cuttlefish" + execution = "mussel" + + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) + actual = PipelineServiceClient.execution_path( + project, location, metadata_store, execution + ) + assert expected == actual + + +def test_parse_execution_path(): + expected = { + "project": "winkle", + "location": "nautilus", + "metadata_store": "scallop", + "execution": "abalone", + } + path = PipelineServiceClient.execution_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_execution_path(path) + assert expected == actual + + def test_model_path(): - project = "cuttlefish" - location = "mussel" - model = "winkle" + project = "squid" + location = "clam" + model = "whelk" expected = "projects/{project}/locations/{location}/models/{model}".format( project=project, location=location, model=model, @@ -2230,9 +3677,9 @@ def test_model_path(): def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "octopus", + "location": "oyster", + "model": "nudibranch", } path = PipelineServiceClient.model_path(**expected) @@ -2241,10 +3688,58 @@ def test_parse_model_path(): assert expected == actual +def test_network_path(): + project = "cuttlefish" + network = "mussel" + + expected = "projects/{project}/global/networks/{network}".format( + project=project, network=network, + ) + actual = PipelineServiceClient.network_path(project, network) + assert expected == actual + + +def test_parse_network_path(): + expected = { + "project": "winkle", + "network": "nautilus", + } + path = PipelineServiceClient.network_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_network_path(path) + assert expected == actual + + +def test_pipeline_job_path(): + project = "scallop" + location = "abalone" + pipeline_job = "squid" + + expected = "projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}".format( + project=project, location=location, pipeline_job=pipeline_job, + ) + actual = PipelineServiceClient.pipeline_job_path(project, location, pipeline_job) + assert expected == actual + + +def test_parse_pipeline_job_path(): + expected = { + "project": "clam", + "location": "whelk", + "pipeline_job": "octopus", + } + path = PipelineServiceClient.pipeline_job_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_pipeline_job_path(path) + assert expected == actual + + def test_training_pipeline_path(): - project = "squid" - location = "clam" - training_pipeline = "whelk" + project = "oyster" + location = "nudibranch" + training_pipeline = "cuttlefish" expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( project=project, location=location, training_pipeline=training_pipeline, @@ -2257,9 +3752,9 @@ def test_training_pipeline_path(): def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", + "project": "mussel", + "location": "winkle", + "training_pipeline": "nautilus", } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2269,7 +3764,7 @@ def test_parse_training_pipeline_path(): def test_common_billing_account_path(): - billing_account = "cuttlefish" + billing_account = "scallop" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, @@ -2280,7 +3775,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "abalone", } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2290,7 +3785,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "winkle" + folder = "squid" expected = "folders/{folder}".format(folder=folder,) actual = PipelineServiceClient.common_folder_path(folder) @@ -2299,7 +3794,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "clam", } path = PipelineServiceClient.common_folder_path(**expected) @@ -2309,7 +3804,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "scallop" + organization = "whelk" expected = "organizations/{organization}".format(organization=organization,) actual = PipelineServiceClient.common_organization_path(organization) @@ -2318,7 +3813,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "octopus", } path = PipelineServiceClient.common_organization_path(**expected) @@ -2328,7 +3823,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "squid" + project = "oyster" expected = "projects/{project}".format(project=project,) actual = PipelineServiceClient.common_project_path(project) @@ -2337,7 +3832,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "nudibranch", } path = PipelineServiceClient.common_project_path(**expected) @@ -2347,8 +3842,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "whelk" - location = "octopus" + project = "cuttlefish" + location = "mussel" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -2359,8 +3854,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "winkle", + "location": "nautilus", } path = PipelineServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py new file mode 100644 index 0000000000..cfbde666ce --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py @@ -0,0 +1,8115 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import ( + TensorboardServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import ( + TensorboardServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import pagers +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import transports +from google.cloud.aiplatform_v1beta1.types import encryption_spec +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard as gca_tensorboard +from google.cloud.aiplatform_v1beta1.types import tensorboard_data +from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_experiment as gca_tensorboard_experiment, +) +from google.cloud.aiplatform_v1beta1.types import tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform_v1beta1.types import tensorboard_service +from google.cloud.aiplatform_v1beta1.types import tensorboard_time_series +from google.cloud.aiplatform_v1beta1.types import ( + tensorboard_time_series as gca_tensorboard_time_series, +) +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert TensorboardServiceClient._get_default_mtls_endpoint(None) is None + assert ( + TensorboardServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + TensorboardServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + TensorboardServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + TensorboardServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + TensorboardServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [TensorboardServiceClient, TensorboardServiceAsyncClient,] +) +def test_tensorboard_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [TensorboardServiceClient, TensorboardServiceAsyncClient,] +) +def test_tensorboard_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_tensorboard_service_client_get_transport_class(): + transport = TensorboardServiceClient.get_transport_class() + available_transports = [ + transports.TensorboardServiceGrpcTransport, + ] + assert transport in available_transports + + transport = TensorboardServiceClient.get_transport_class("grpc") + assert transport == transports.TensorboardServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport, "grpc"), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + TensorboardServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceClient), +) +@mock.patch.object( + TensorboardServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceAsyncClient), +) +def test_tensorboard_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(TensorboardServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(TensorboardServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + TensorboardServiceClient, + transports.TensorboardServiceGrpcTransport, + "grpc", + "true", + ), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + TensorboardServiceClient, + transports.TensorboardServiceGrpcTransport, + "grpc", + "false", + ), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + TensorboardServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceClient), +) +@mock.patch.object( + TensorboardServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TensorboardServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_tensorboard_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport, "grpc"), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_tensorboard_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TensorboardServiceClient, transports.TensorboardServiceGrpcTransport, "grpc"), + ( + TensorboardServiceAsyncClient, + transports.TensorboardServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_tensorboard_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_tensorboard_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports.TensorboardServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = TensorboardServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_tensorboard( + transport: str = "grpc", request_type=tensorboard_service.CreateTensorboardRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_tensorboard_from_dict(): + test_create_tensorboard(request_type=dict) + + +def test_create_tensorboard_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + client.create_tensorboard() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardRequest() + + +@pytest.mark.asyncio +async def test_create_tensorboard_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.CreateTensorboardRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_tensorboard_async_from_dict(): + await test_create_tensorboard_async(request_type=dict) + + +def test_create_tensorboard_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_tensorboard_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_tensorboard_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_tensorboard( + parent="parent_value", + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].tensorboard == gca_tensorboard.Tensorboard(name="name_value") + + +def test_create_tensorboard_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_tensorboard( + tensorboard_service.CreateTensorboardRequest(), + parent="parent_value", + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_tensorboard_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_tensorboard( + parent="parent_value", + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].tensorboard == gca_tensorboard.Tensorboard(name="name_value") + + +@pytest.mark.asyncio +async def test_create_tensorboard_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_tensorboard( + tensorboard_service.CreateTensorboardRequest(), + parent="parent_value", + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + ) + + +def test_get_tensorboard( + transport: str = "grpc", request_type=tensorboard_service.GetTensorboardRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard.Tensorboard( + name="name_value", + display_name="display_name_value", + description="description_value", + blob_storage_path_prefix="blob_storage_path_prefix_value", + run_count=989, + etag="etag_value", + ) + + response = client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, tensorboard.Tensorboard) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.blob_storage_path_prefix == "blob_storage_path_prefix_value" + + assert response.run_count == 989 + + assert response.etag == "etag_value" + + +def test_get_tensorboard_from_dict(): + test_get_tensorboard(request_type=dict) + + +def test_get_tensorboard_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + client.get_tensorboard() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardRequest() + + +@pytest.mark.asyncio +async def test_get_tensorboard_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.GetTensorboardRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard.Tensorboard( + name="name_value", + display_name="display_name_value", + description="description_value", + blob_storage_path_prefix="blob_storage_path_prefix_value", + run_count=989, + etag="etag_value", + ) + ) + + response = await client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, tensorboard.Tensorboard) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.blob_storage_path_prefix == "blob_storage_path_prefix_value" + + assert response.run_count == 989 + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_async_from_dict(): + await test_get_tensorboard_async(request_type=dict) + + +def test_get_tensorboard_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + call.return_value = tensorboard.Tensorboard() + + client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_tensorboard_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard.Tensorboard() + ) + + await client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_tensorboard_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard.Tensorboard() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_tensorboard(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_tensorboard_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tensorboard( + tensorboard_service.GetTensorboardRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_tensorboard_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard.Tensorboard() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard.Tensorboard() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_tensorboard(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_tensorboard( + tensorboard_service.GetTensorboardRequest(), name="name_value", + ) + + +def test_update_tensorboard( + transport: str = "grpc", request_type=tensorboard_service.UpdateTensorboardRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_tensorboard_from_dict(): + test_update_tensorboard(request_type=dict) + + +def test_update_tensorboard_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + client.update_tensorboard() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardRequest() + + +@pytest.mark.asyncio +async def test_update_tensorboard_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.UpdateTensorboardRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_tensorboard_async_from_dict(): + await test_update_tensorboard_async(request_type=dict) + + +def test_update_tensorboard_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardRequest() + request.tensorboard.name = "tensorboard.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "tensorboard.name=tensorboard.name/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_update_tensorboard_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardRequest() + request.tensorboard.name = "tensorboard.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "tensorboard.name=tensorboard.name/value",) in kw[ + "metadata" + ] + + +def test_update_tensorboard_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_tensorboard( + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard == gca_tensorboard.Tensorboard(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_tensorboard_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tensorboard( + tensorboard_service.UpdateTensorboardRequest(), + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_tensorboard_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_tensorboard( + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard == gca_tensorboard.Tensorboard(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_tensorboard_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_tensorboard( + tensorboard_service.UpdateTensorboardRequest(), + tensorboard=gca_tensorboard.Tensorboard(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_tensorboards( + transport: str = "grpc", request_type=tensorboard_service.ListTensorboardsRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListTensorboardsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_tensorboards_from_dict(): + test_list_tensorboards(request_type=dict) + + +def test_list_tensorboards_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + client.list_tensorboards() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardsRequest() + + +@pytest.mark.asyncio +async def test_list_tensorboards_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ListTensorboardsRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTensorboardsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_tensorboards_async_from_dict(): + await test_list_tensorboards_async(request_type=dict) + + +def test_list_tensorboards_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + call.return_value = tensorboard_service.ListTensorboardsResponse() + + client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_tensorboards_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardsResponse() + ) + + await client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_tensorboards_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_tensorboards(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_tensorboards_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tensorboards( + tensorboard_service.ListTensorboardsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_tensorboards_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_tensorboards(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_tensorboards_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_tensorboards( + tensorboard_service.ListTensorboardsRequest(), parent="parent_value", + ) + + +def test_list_tensorboards_pager(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardsResponse( + tensorboards=[ + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(),], next_page_token="ghi", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(), tensorboard.Tensorboard(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_tensorboards(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, tensorboard.Tensorboard) for i in results) + + +def test_list_tensorboards_pages(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardsResponse( + tensorboards=[ + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(),], next_page_token="ghi", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(), tensorboard.Tensorboard(),], + ), + RuntimeError, + ) + pages = list(client.list_tensorboards(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_tensorboards_async_pager(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardsResponse( + tensorboards=[ + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(),], next_page_token="ghi", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(), tensorboard.Tensorboard(),], + ), + RuntimeError, + ) + async_pager = await client.list_tensorboards(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, tensorboard.Tensorboard) for i in responses) + + +@pytest.mark.asyncio +async def test_list_tensorboards_async_pages(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboards), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardsResponse( + tensorboards=[ + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + tensorboard.Tensorboard(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(),], next_page_token="ghi", + ), + tensorboard_service.ListTensorboardsResponse( + tensorboards=[tensorboard.Tensorboard(), tensorboard.Tensorboard(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_tensorboards(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_tensorboard( + transport: str = "grpc", request_type=tensorboard_service.DeleteTensorboardRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_tensorboard_from_dict(): + test_delete_tensorboard(request_type=dict) + + +def test_delete_tensorboard_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + client.delete_tensorboard() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardRequest() + + +@pytest.mark.asyncio +async def test_delete_tensorboard_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.DeleteTensorboardRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_async_from_dict(): + await test_delete_tensorboard_async(request_type=dict) + + +def test_delete_tensorboard_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_tensorboard_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_tensorboard_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_tensorboard(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_tensorboard_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tensorboard( + tensorboard_service.DeleteTensorboardRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_tensorboard(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_tensorboard_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_tensorboard( + tensorboard_service.DeleteTensorboardRequest(), name="name_value", + ) + + +def test_create_tensorboard_experiment( + transport: str = "grpc", + request_type=tensorboard_service.CreateTensorboardExperimentRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_experiment.TensorboardExperiment( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + source="source_value", + ) + + response = client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_tensorboard_experiment.TensorboardExperiment) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + assert response.source == "source_value" + + +def test_create_tensorboard_experiment_from_dict(): + test_create_tensorboard_experiment(request_type=dict) + + +def test_create_tensorboard_experiment_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + client.create_tensorboard_experiment() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardExperimentRequest() + + +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.CreateTensorboardExperimentRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_experiment.TensorboardExperiment( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + source="source_value", + ) + ) + + response = await client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_tensorboard_experiment.TensorboardExperiment) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + assert response.source == "source_value" + + +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_async_from_dict(): + await test_create_tensorboard_experiment_async(request_type=dict) + + +def test_create_tensorboard_experiment_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardExperimentRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + call.return_value = gca_tensorboard_experiment.TensorboardExperiment() + + client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardExperimentRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_experiment.TensorboardExperiment() + ) + + await client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_tensorboard_experiment_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_experiment.TensorboardExperiment() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_tensorboard_experiment( + parent="parent_value", + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + tensorboard_experiment_id="tensorboard_experiment_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].tensorboard_experiment == gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ) + + assert args[0].tensorboard_experiment_id == "tensorboard_experiment_id_value" + + +def test_create_tensorboard_experiment_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_tensorboard_experiment( + tensorboard_service.CreateTensorboardExperimentRequest(), + parent="parent_value", + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + tensorboard_experiment_id="tensorboard_experiment_id_value", + ) + + +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_experiment.TensorboardExperiment() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_experiment.TensorboardExperiment() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_tensorboard_experiment( + parent="parent_value", + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + tensorboard_experiment_id="tensorboard_experiment_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].tensorboard_experiment == gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ) + + assert args[0].tensorboard_experiment_id == "tensorboard_experiment_id_value" + + +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_tensorboard_experiment( + tensorboard_service.CreateTensorboardExperimentRequest(), + parent="parent_value", + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + tensorboard_experiment_id="tensorboard_experiment_id_value", + ) + + +def test_get_tensorboard_experiment( + transport: str = "grpc", + request_type=tensorboard_service.GetTensorboardExperimentRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_experiment.TensorboardExperiment( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + source="source_value", + ) + + response = client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, tensorboard_experiment.TensorboardExperiment) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + assert response.source == "source_value" + + +def test_get_tensorboard_experiment_from_dict(): + test_get_tensorboard_experiment(request_type=dict) + + +def test_get_tensorboard_experiment_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + client.get_tensorboard_experiment() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardExperimentRequest() + + +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.GetTensorboardExperimentRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_experiment.TensorboardExperiment( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + source="source_value", + ) + ) + + response = await client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, tensorboard_experiment.TensorboardExperiment) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + assert response.source == "source_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_async_from_dict(): + await test_get_tensorboard_experiment_async(request_type=dict) + + +def test_get_tensorboard_experiment_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardExperimentRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + call.return_value = tensorboard_experiment.TensorboardExperiment() + + client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardExperimentRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_experiment.TensorboardExperiment() + ) + + await client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_tensorboard_experiment_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_experiment.TensorboardExperiment() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_tensorboard_experiment(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_tensorboard_experiment_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tensorboard_experiment( + tensorboard_service.GetTensorboardExperimentRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_experiment.TensorboardExperiment() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_experiment.TensorboardExperiment() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_tensorboard_experiment(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_tensorboard_experiment( + tensorboard_service.GetTensorboardExperimentRequest(), name="name_value", + ) + + +def test_update_tensorboard_experiment( + transport: str = "grpc", + request_type=tensorboard_service.UpdateTensorboardExperimentRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_experiment.TensorboardExperiment( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + source="source_value", + ) + + response = client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_tensorboard_experiment.TensorboardExperiment) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + assert response.source == "source_value" + + +def test_update_tensorboard_experiment_from_dict(): + test_update_tensorboard_experiment(request_type=dict) + + +def test_update_tensorboard_experiment_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + client.update_tensorboard_experiment() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() + + +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.UpdateTensorboardExperimentRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_experiment.TensorboardExperiment( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + source="source_value", + ) + ) + + response = await client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_tensorboard_experiment.TensorboardExperiment) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + assert response.source == "source_value" + + +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_async_from_dict(): + await test_update_tensorboard_experiment_async(request_type=dict) + + +def test_update_tensorboard_experiment_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardExperimentRequest() + request.tensorboard_experiment.name = "tensorboard_experiment.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + call.return_value = gca_tensorboard_experiment.TensorboardExperiment() + + client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_experiment.name=tensorboard_experiment.name/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardExperimentRequest() + request.tensorboard_experiment.name = "tensorboard_experiment.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_experiment.TensorboardExperiment() + ) + + await client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_experiment.name=tensorboard_experiment.name/value", + ) in kw["metadata"] + + +def test_update_tensorboard_experiment_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_experiment.TensorboardExperiment() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_tensorboard_experiment( + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[ + 0 + ].tensorboard_experiment == gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_tensorboard_experiment_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tensorboard_experiment( + tensorboard_service.UpdateTensorboardExperimentRequest(), + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_experiment.TensorboardExperiment() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_experiment.TensorboardExperiment() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_tensorboard_experiment( + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[ + 0 + ].tensorboard_experiment == gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_tensorboard_experiment( + tensorboard_service.UpdateTensorboardExperimentRequest(), + tensorboard_experiment=gca_tensorboard_experiment.TensorboardExperiment( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_tensorboard_experiments( + transport: str = "grpc", + request_type=tensorboard_service.ListTensorboardExperimentsRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardExperimentsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardExperimentsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListTensorboardExperimentsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_tensorboard_experiments_from_dict(): + test_list_tensorboard_experiments(request_type=dict) + + +def test_list_tensorboard_experiments_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + client.list_tensorboard_experiments() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardExperimentsRequest() + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ListTensorboardExperimentsRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardExperimentsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardExperimentsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTensorboardExperimentsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_async_from_dict(): + await test_list_tensorboard_experiments_async(request_type=dict) + + +def test_list_tensorboard_experiments_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardExperimentsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + call.return_value = tensorboard_service.ListTensorboardExperimentsResponse() + + client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardExperimentsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardExperimentsResponse() + ) + + await client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_tensorboard_experiments_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardExperimentsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_tensorboard_experiments(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_tensorboard_experiments_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tensorboard_experiments( + tensorboard_service.ListTensorboardExperimentsRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardExperimentsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardExperimentsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_tensorboard_experiments(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_tensorboard_experiments( + tensorboard_service.ListTensorboardExperimentsRequest(), + parent="parent_value", + ) + + +def test_list_tensorboard_experiments_pager(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_tensorboard_experiments(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all( + isinstance(i, tensorboard_experiment.TensorboardExperiment) for i in results + ) + + +def test_list_tensorboard_experiments_pages(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + ), + RuntimeError, + ) + pages = list(client.list_tensorboard_experiments(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_async_pager(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_tensorboard_experiments(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all( + isinstance(i, tensorboard_experiment.TensorboardExperiment) + for i in responses + ) + + +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_async_pages(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_experiments), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardExperimentsResponse( + tensorboard_experiments=[ + tensorboard_experiment.TensorboardExperiment(), + tensorboard_experiment.TensorboardExperiment(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in ( + await client.list_tensorboard_experiments(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_tensorboard_experiment( + transport: str = "grpc", + request_type=tensorboard_service.DeleteTensorboardExperimentRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_tensorboard_experiment_from_dict(): + test_delete_tensorboard_experiment(request_type=dict) + + +def test_delete_tensorboard_experiment_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + client.delete_tensorboard_experiment() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.DeleteTensorboardExperimentRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_async_from_dict(): + await test_delete_tensorboard_experiment_async(request_type=dict) + + +def test_delete_tensorboard_experiment_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardExperimentRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardExperimentRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_tensorboard_experiment_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_tensorboard_experiment(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_tensorboard_experiment_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tensorboard_experiment( + tensorboard_service.DeleteTensorboardExperimentRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_experiment), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_tensorboard_experiment(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_tensorboard_experiment( + tensorboard_service.DeleteTensorboardExperimentRequest(), name="name_value", + ) + + +def test_create_tensorboard_run( + transport: str = "grpc", + request_type=tensorboard_service.CreateTensorboardRunRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_run.TensorboardRun( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardRunRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_tensorboard_run.TensorboardRun) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_create_tensorboard_run_from_dict(): + test_create_tensorboard_run(request_type=dict) + + +def test_create_tensorboard_run_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + client.create_tensorboard_run() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardRunRequest() + + +@pytest.mark.asyncio +async def test_create_tensorboard_run_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.CreateTensorboardRunRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_run.TensorboardRun( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardRunRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_tensorboard_run.TensorboardRun) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_create_tensorboard_run_async_from_dict(): + await test_create_tensorboard_run_async(request_type=dict) + + +def test_create_tensorboard_run_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardRunRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + call.return_value = gca_tensorboard_run.TensorboardRun() + + client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_tensorboard_run_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardRunRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_run.TensorboardRun() + ) + + await client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_tensorboard_run_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_run.TensorboardRun() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_tensorboard_run( + parent="parent_value", + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + tensorboard_run_id="tensorboard_run_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].tensorboard_run == gca_tensorboard_run.TensorboardRun( + name="name_value" + ) + + assert args[0].tensorboard_run_id == "tensorboard_run_id_value" + + +def test_create_tensorboard_run_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_tensorboard_run( + tensorboard_service.CreateTensorboardRunRequest(), + parent="parent_value", + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + tensorboard_run_id="tensorboard_run_id_value", + ) + + +@pytest.mark.asyncio +async def test_create_tensorboard_run_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_run.TensorboardRun() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_run.TensorboardRun() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_tensorboard_run( + parent="parent_value", + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + tensorboard_run_id="tensorboard_run_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].tensorboard_run == gca_tensorboard_run.TensorboardRun( + name="name_value" + ) + + assert args[0].tensorboard_run_id == "tensorboard_run_id_value" + + +@pytest.mark.asyncio +async def test_create_tensorboard_run_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_tensorboard_run( + tensorboard_service.CreateTensorboardRunRequest(), + parent="parent_value", + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + tensorboard_run_id="tensorboard_run_id_value", + ) + + +def test_get_tensorboard_run( + transport: str = "grpc", request_type=tensorboard_service.GetTensorboardRunRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_run.TensorboardRun( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardRunRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, tensorboard_run.TensorboardRun) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_get_tensorboard_run_from_dict(): + test_get_tensorboard_run(request_type=dict) + + +def test_get_tensorboard_run_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + client.get_tensorboard_run() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardRunRequest() + + +@pytest.mark.asyncio +async def test_get_tensorboard_run_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.GetTensorboardRunRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_run.TensorboardRun( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardRunRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, tensorboard_run.TensorboardRun) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_run_async_from_dict(): + await test_get_tensorboard_run_async(request_type=dict) + + +def test_get_tensorboard_run_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardRunRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + call.return_value = tensorboard_run.TensorboardRun() + + client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_tensorboard_run_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardRunRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_run.TensorboardRun() + ) + + await client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_tensorboard_run_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_run.TensorboardRun() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_tensorboard_run(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_tensorboard_run_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tensorboard_run( + tensorboard_service.GetTensorboardRunRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_tensorboard_run_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_run.TensorboardRun() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_run.TensorboardRun() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_tensorboard_run(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_run_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_tensorboard_run( + tensorboard_service.GetTensorboardRunRequest(), name="name_value", + ) + + +def test_update_tensorboard_run( + transport: str = "grpc", + request_type=tensorboard_service.UpdateTensorboardRunRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_run.TensorboardRun( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_tensorboard_run.TensorboardRun) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_update_tensorboard_run_from_dict(): + test_update_tensorboard_run(request_type=dict) + + +def test_update_tensorboard_run_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + client.update_tensorboard_run() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() + + +@pytest.mark.asyncio +async def test_update_tensorboard_run_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.UpdateTensorboardRunRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_run.TensorboardRun( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_tensorboard_run.TensorboardRun) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_tensorboard_run_async_from_dict(): + await test_update_tensorboard_run_async(request_type=dict) + + +def test_update_tensorboard_run_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardRunRequest() + request.tensorboard_run.name = "tensorboard_run.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + call.return_value = gca_tensorboard_run.TensorboardRun() + + client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_run.name=tensorboard_run.name/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_tensorboard_run_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardRunRequest() + request.tensorboard_run.name = "tensorboard_run.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_run.TensorboardRun() + ) + + await client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_run.name=tensorboard_run.name/value", + ) in kw["metadata"] + + +def test_update_tensorboard_run_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_run.TensorboardRun() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_tensorboard_run( + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_run == gca_tensorboard_run.TensorboardRun( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_tensorboard_run_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tensorboard_run( + tensorboard_service.UpdateTensorboardRunRequest(), + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_tensorboard_run_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_run.TensorboardRun() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_run.TensorboardRun() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_tensorboard_run( + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_run == gca_tensorboard_run.TensorboardRun( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_tensorboard_run_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_tensorboard_run( + tensorboard_service.UpdateTensorboardRunRequest(), + tensorboard_run=gca_tensorboard_run.TensorboardRun(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_tensorboard_runs( + transport: str = "grpc", request_type=tensorboard_service.ListTensorboardRunsRequest +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardRunsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardRunsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListTensorboardRunsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_tensorboard_runs_from_dict(): + test_list_tensorboard_runs(request_type=dict) + + +def test_list_tensorboard_runs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + client.list_tensorboard_runs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardRunsRequest() + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ListTensorboardRunsRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardRunsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardRunsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTensorboardRunsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_async_from_dict(): + await test_list_tensorboard_runs_async(request_type=dict) + + +def test_list_tensorboard_runs_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardRunsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + call.return_value = tensorboard_service.ListTensorboardRunsResponse() + + client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardRunsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardRunsResponse() + ) + + await client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_tensorboard_runs_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardRunsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_tensorboard_runs(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_tensorboard_runs_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tensorboard_runs( + tensorboard_service.ListTensorboardRunsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardRunsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardRunsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_tensorboard_runs(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_tensorboard_runs( + tensorboard_service.ListTensorboardRunsRequest(), parent="parent_value", + ) + + +def test_list_tensorboard_runs_pager(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[tensorboard_run.TensorboardRun(),], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_tensorboard_runs(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, tensorboard_run.TensorboardRun) for i in results) + + +def test_list_tensorboard_runs_pages(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[tensorboard_run.TensorboardRun(),], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + ), + RuntimeError, + ) + pages = list(client.list_tensorboard_runs(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_async_pager(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[tensorboard_run.TensorboardRun(),], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_tensorboard_runs(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, tensorboard_run.TensorboardRun) for i in responses) + + +@pytest.mark.asyncio +async def test_list_tensorboard_runs_async_pages(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_runs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[tensorboard_run.TensorboardRun(),], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardRunsResponse( + tensorboard_runs=[ + tensorboard_run.TensorboardRun(), + tensorboard_run.TensorboardRun(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_tensorboard_runs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_tensorboard_run( + transport: str = "grpc", + request_type=tensorboard_service.DeleteTensorboardRunRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardRunRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_tensorboard_run_from_dict(): + test_delete_tensorboard_run(request_type=dict) + + +def test_delete_tensorboard_run_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + client.delete_tensorboard_run() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardRunRequest() + + +@pytest.mark.asyncio +async def test_delete_tensorboard_run_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.DeleteTensorboardRunRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardRunRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_run_async_from_dict(): + await test_delete_tensorboard_run_async(request_type=dict) + + +def test_delete_tensorboard_run_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardRunRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_tensorboard_run_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardRunRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_tensorboard_run_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_tensorboard_run(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_tensorboard_run_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tensorboard_run( + tensorboard_service.DeleteTensorboardRunRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_run_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_run), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_tensorboard_run(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_tensorboard_run_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_tensorboard_run( + tensorboard_service.DeleteTensorboardRunRequest(), name="name_value", + ) + + +def test_create_tensorboard_time_series( + transport: str = "grpc", + request_type=tensorboard_service.CreateTensorboardTimeSeriesRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value", + display_name="display_name_value", + description="description_value", + value_type=gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + etag="etag_value", + plugin_name="plugin_name_value", + plugin_data=b"plugin_data_blob", + ) + + response = client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_tensorboard_time_series.TensorboardTimeSeries) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert ( + response.value_type + == gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + + assert response.etag == "etag_value" + + assert response.plugin_name == "plugin_name_value" + + assert response.plugin_data == b"plugin_data_blob" + + +def test_create_tensorboard_time_series_from_dict(): + test_create_tensorboard_time_series(request_type=dict) + + +def test_create_tensorboard_time_series_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + client.create_tensorboard_time_series() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardTimeSeriesRequest() + + +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.CreateTensorboardTimeSeriesRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value", + display_name="display_name_value", + description="description_value", + value_type=gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + etag="etag_value", + plugin_name="plugin_name_value", + plugin_data=b"plugin_data_blob", + ) + ) + + response = await client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.CreateTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_tensorboard_time_series.TensorboardTimeSeries) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert ( + response.value_type + == gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + + assert response.etag == "etag_value" + + assert response.plugin_name == "plugin_name_value" + + assert response.plugin_data == b"plugin_data_blob" + + +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_async_from_dict(): + await test_create_tensorboard_time_series_async(request_type=dict) + + +def test_create_tensorboard_time_series_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardTimeSeriesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries() + + client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.CreateTensorboardTimeSeriesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_time_series.TensorboardTimeSeries() + ) + + await client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_tensorboard_time_series_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_tensorboard_time_series( + parent="parent_value", + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].tensorboard_time_series == gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ) + + +def test_create_tensorboard_time_series_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_tensorboard_time_series( + tensorboard_service.CreateTensorboardTimeSeriesRequest(), + parent="parent_value", + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + ) + + +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_time_series.TensorboardTimeSeries() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_tensorboard_time_series( + parent="parent_value", + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].tensorboard_time_series == gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_tensorboard_time_series( + tensorboard_service.CreateTensorboardTimeSeriesRequest(), + parent="parent_value", + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + ) + + +def test_get_tensorboard_time_series( + transport: str = "grpc", + request_type=tensorboard_service.GetTensorboardTimeSeriesRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_time_series.TensorboardTimeSeries( + name="name_value", + display_name="display_name_value", + description="description_value", + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + etag="etag_value", + plugin_name="plugin_name_value", + plugin_data=b"plugin_data_blob", + ) + + response = client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, tensorboard_time_series.TensorboardTimeSeries) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert ( + response.value_type + == tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + + assert response.etag == "etag_value" + + assert response.plugin_name == "plugin_name_value" + + assert response.plugin_data == b"plugin_data_blob" + + +def test_get_tensorboard_time_series_from_dict(): + test_get_tensorboard_time_series(request_type=dict) + + +def test_get_tensorboard_time_series_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + client.get_tensorboard_time_series() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardTimeSeriesRequest() + + +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.GetTensorboardTimeSeriesRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_time_series.TensorboardTimeSeries( + name="name_value", + display_name="display_name_value", + description="description_value", + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + etag="etag_value", + plugin_name="plugin_name_value", + plugin_data=b"plugin_data_blob", + ) + ) + + response = await client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.GetTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, tensorboard_time_series.TensorboardTimeSeries) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert ( + response.value_type + == tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + + assert response.etag == "etag_value" + + assert response.plugin_name == "plugin_name_value" + + assert response.plugin_data == b"plugin_data_blob" + + +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_async_from_dict(): + await test_get_tensorboard_time_series_async(request_type=dict) + + +def test_get_tensorboard_time_series_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardTimeSeriesRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + call.return_value = tensorboard_time_series.TensorboardTimeSeries() + + client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.GetTensorboardTimeSeriesRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_time_series.TensorboardTimeSeries() + ) + + await client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_tensorboard_time_series_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_time_series.TensorboardTimeSeries() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_tensorboard_time_series(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_tensorboard_time_series_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tensorboard_time_series( + tensorboard_service.GetTensorboardTimeSeriesRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_time_series.TensorboardTimeSeries() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_time_series.TensorboardTimeSeries() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_tensorboard_time_series(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_tensorboard_time_series( + tensorboard_service.GetTensorboardTimeSeriesRequest(), name="name_value", + ) + + +def test_update_tensorboard_time_series( + transport: str = "grpc", + request_type=tensorboard_service.UpdateTensorboardTimeSeriesRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value", + display_name="display_name_value", + description="description_value", + value_type=gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + etag="etag_value", + plugin_name="plugin_name_value", + plugin_data=b"plugin_data_blob", + ) + + response = client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_tensorboard_time_series.TensorboardTimeSeries) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert ( + response.value_type + == gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + + assert response.etag == "etag_value" + + assert response.plugin_name == "plugin_name_value" + + assert response.plugin_data == b"plugin_data_blob" + + +def test_update_tensorboard_time_series_from_dict(): + test_update_tensorboard_time_series(request_type=dict) + + +def test_update_tensorboard_time_series_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + client.update_tensorboard_time_series() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() + + +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.UpdateTensorboardTimeSeriesRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value", + display_name="display_name_value", + description="description_value", + value_type=gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + etag="etag_value", + plugin_name="plugin_name_value", + plugin_data=b"plugin_data_blob", + ) + ) + + response = await client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_tensorboard_time_series.TensorboardTimeSeries) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert ( + response.value_type + == gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + + assert response.etag == "etag_value" + + assert response.plugin_name == "plugin_name_value" + + assert response.plugin_data == b"plugin_data_blob" + + +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_async_from_dict(): + await test_update_tensorboard_time_series_async(request_type=dict) + + +def test_update_tensorboard_time_series_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardTimeSeriesRequest() + request.tensorboard_time_series.name = "tensorboard_time_series.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries() + + client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_time_series.name=tensorboard_time_series.name/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.UpdateTensorboardTimeSeriesRequest() + request.tensorboard_time_series.name = "tensorboard_time_series.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_time_series.TensorboardTimeSeries() + ) + + await client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_time_series.name=tensorboard_time_series.name/value", + ) in kw["metadata"] + + +def test_update_tensorboard_time_series_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_tensorboard_time_series( + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[ + 0 + ].tensorboard_time_series == gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_tensorboard_time_series_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tensorboard_time_series( + tensorboard_service.UpdateTensorboardTimeSeriesRequest(), + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_tensorboard_time_series.TensorboardTimeSeries() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_tensorboard_time_series.TensorboardTimeSeries() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_tensorboard_time_series( + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[ + 0 + ].tensorboard_time_series == gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_tensorboard_time_series( + tensorboard_service.UpdateTensorboardTimeSeriesRequest(), + tensorboard_time_series=gca_tensorboard_time_series.TensorboardTimeSeries( + name="name_value" + ), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_tensorboard_time_series( + transport: str = "grpc", + request_type=tensorboard_service.ListTensorboardTimeSeriesRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardTimeSeriesResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListTensorboardTimeSeriesPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_tensorboard_time_series_from_dict(): + test_list_tensorboard_time_series(request_type=dict) + + +def test_list_tensorboard_time_series_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + client.list_tensorboard_time_series() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardTimeSeriesRequest() + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ListTensorboardTimeSeriesRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardTimeSeriesResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ListTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTensorboardTimeSeriesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_async_from_dict(): + await test_list_tensorboard_time_series_async(request_type=dict) + + +def test_list_tensorboard_time_series_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardTimeSeriesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + call.return_value = tensorboard_service.ListTensorboardTimeSeriesResponse() + + client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ListTensorboardTimeSeriesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardTimeSeriesResponse() + ) + + await client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_tensorboard_time_series_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardTimeSeriesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_tensorboard_time_series(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_tensorboard_time_series_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tensorboard_time_series( + tensorboard_service.ListTensorboardTimeSeriesRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ListTensorboardTimeSeriesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ListTensorboardTimeSeriesResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_tensorboard_time_series(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_tensorboard_time_series( + tensorboard_service.ListTensorboardTimeSeriesRequest(), + parent="parent_value", + ) + + +def test_list_tensorboard_time_series_pager(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_tensorboard_time_series(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all( + isinstance(i, tensorboard_time_series.TensorboardTimeSeries) + for i in results + ) + + +def test_list_tensorboard_time_series_pages(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + ), + RuntimeError, + ) + pages = list(client.list_tensorboard_time_series(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_async_pager(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_tensorboard_time_series(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all( + isinstance(i, tensorboard_time_series.TensorboardTimeSeries) + for i in responses + ) + + +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_async_pages(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tensorboard_time_series), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="abc", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[], next_page_token="def", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + ], + next_page_token="ghi", + ), + tensorboard_service.ListTensorboardTimeSeriesResponse( + tensorboard_time_series=[ + tensorboard_time_series.TensorboardTimeSeries(), + tensorboard_time_series.TensorboardTimeSeries(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in ( + await client.list_tensorboard_time_series(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_tensorboard_time_series( + transport: str = "grpc", + request_type=tensorboard_service.DeleteTensorboardTimeSeriesRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_tensorboard_time_series_from_dict(): + test_delete_tensorboard_time_series(request_type=dict) + + +def test_delete_tensorboard_time_series_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + client.delete_tensorboard_time_series() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardTimeSeriesRequest() + + +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.DeleteTensorboardTimeSeriesRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.DeleteTensorboardTimeSeriesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_async_from_dict(): + await test_delete_tensorboard_time_series_async(request_type=dict) + + +def test_delete_tensorboard_time_series_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardTimeSeriesRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.DeleteTensorboardTimeSeriesRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_tensorboard_time_series_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_tensorboard_time_series(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_tensorboard_time_series_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tensorboard_time_series( + tensorboard_service.DeleteTensorboardTimeSeriesRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tensorboard_time_series), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_tensorboard_time_series(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_tensorboard_time_series( + tensorboard_service.DeleteTensorboardTimeSeriesRequest(), name="name_value", + ) + + +def test_read_tensorboard_time_series_data( + transport: str = "grpc", + request_type=tensorboard_service.ReadTensorboardTimeSeriesDataRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + + response = client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ReadTensorboardTimeSeriesDataRequest() + + # Establish that the response is the type that we expect. + + assert isinstance( + response, tensorboard_service.ReadTensorboardTimeSeriesDataResponse + ) + + +def test_read_tensorboard_time_series_data_from_dict(): + test_read_tensorboard_time_series_data(request_type=dict) + + +def test_read_tensorboard_time_series_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + client.read_tensorboard_time_series_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ReadTensorboardTimeSeriesDataRequest() + + +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ReadTensorboardTimeSeriesDataRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + ) + + response = await client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ReadTensorboardTimeSeriesDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance( + response, tensorboard_service.ReadTensorboardTimeSeriesDataResponse + ) + + +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_async_from_dict(): + await test_read_tensorboard_time_series_data_async(request_type=dict) + + +def test_read_tensorboard_time_series_data_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest() + request.tensorboard_time_series = "tensorboard_time_series/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + call.return_value = tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + + client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_time_series=tensorboard_time_series/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest() + request.tensorboard_time_series = "tensorboard_time_series/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + ) + + await client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_time_series=tensorboard_time_series/value", + ) in kw["metadata"] + + +def test_read_tensorboard_time_series_data_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.read_tensorboard_time_series_data( + tensorboard_time_series="tensorboard_time_series_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_time_series == "tensorboard_time_series_value" + + +def test_read_tensorboard_time_series_data_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.read_tensorboard_time_series_data( + tensorboard_service.ReadTensorboardTimeSeriesDataRequest(), + tensorboard_time_series="tensorboard_time_series_value", + ) + + +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ReadTensorboardTimeSeriesDataResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.read_tensorboard_time_series_data( + tensorboard_time_series="tensorboard_time_series_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_time_series == "tensorboard_time_series_value" + + +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.read_tensorboard_time_series_data( + tensorboard_service.ReadTensorboardTimeSeriesDataRequest(), + tensorboard_time_series="tensorboard_time_series_value", + ) + + +def test_read_tensorboard_blob_data( + transport: str = "grpc", + request_type=tensorboard_service.ReadTensorboardBlobDataRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter( + [tensorboard_service.ReadTensorboardBlobDataResponse()] + ) + + response = client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ReadTensorboardBlobDataRequest() + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, tensorboard_service.ReadTensorboardBlobDataResponse) + + +def test_read_tensorboard_blob_data_from_dict(): + test_read_tensorboard_blob_data(request_type=dict) + + +def test_read_tensorboard_blob_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + client.read_tensorboard_blob_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ReadTensorboardBlobDataRequest() + + +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ReadTensorboardBlobDataRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[tensorboard_service.ReadTensorboardBlobDataResponse()] + ) + + response = await client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ReadTensorboardBlobDataRequest() + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, tensorboard_service.ReadTensorboardBlobDataResponse) + + +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_async_from_dict(): + await test_read_tensorboard_blob_data_async(request_type=dict) + + +def test_read_tensorboard_blob_data_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ReadTensorboardBlobDataRequest() + request.time_series = "time_series/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + call.return_value = iter( + [tensorboard_service.ReadTensorboardBlobDataResponse()] + ) + + client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "time_series=time_series/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ReadTensorboardBlobDataRequest() + request.time_series = "time_series/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[tensorboard_service.ReadTensorboardBlobDataResponse()] + ) + + await client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "time_series=time_series/value",) in kw["metadata"] + + +def test_read_tensorboard_blob_data_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter( + [tensorboard_service.ReadTensorboardBlobDataResponse()] + ) + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.read_tensorboard_blob_data(time_series="time_series_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].time_series == "time_series_value" + + +def test_read_tensorboard_blob_data_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.read_tensorboard_blob_data( + tensorboard_service.ReadTensorboardBlobDataRequest(), + time_series="time_series_value", + ) + + +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.read_tensorboard_blob_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter( + [tensorboard_service.ReadTensorboardBlobDataResponse()] + ) + + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.read_tensorboard_blob_data( + time_series="time_series_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].time_series == "time_series_value" + + +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.read_tensorboard_blob_data( + tensorboard_service.ReadTensorboardBlobDataRequest(), + time_series="time_series_value", + ) + + +def test_write_tensorboard_run_data( + transport: str = "grpc", + request_type=tensorboard_service.WriteTensorboardRunDataRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.WriteTensorboardRunDataResponse() + + response = client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.WriteTensorboardRunDataRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, tensorboard_service.WriteTensorboardRunDataResponse) + + +def test_write_tensorboard_run_data_from_dict(): + test_write_tensorboard_run_data(request_type=dict) + + +def test_write_tensorboard_run_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + client.write_tensorboard_run_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.WriteTensorboardRunDataRequest() + + +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.WriteTensorboardRunDataRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.WriteTensorboardRunDataResponse() + ) + + response = await client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.WriteTensorboardRunDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, tensorboard_service.WriteTensorboardRunDataResponse) + + +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_async_from_dict(): + await test_write_tensorboard_run_data_async(request_type=dict) + + +def test_write_tensorboard_run_data_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.WriteTensorboardRunDataRequest() + request.tensorboard_run = "tensorboard_run/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + call.return_value = tensorboard_service.WriteTensorboardRunDataResponse() + + client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "tensorboard_run=tensorboard_run/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.WriteTensorboardRunDataRequest() + request.tensorboard_run = "tensorboard_run/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.WriteTensorboardRunDataResponse() + ) + + await client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "tensorboard_run=tensorboard_run/value",) in kw[ + "metadata" + ] + + +def test_write_tensorboard_run_data_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.WriteTensorboardRunDataResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.write_tensorboard_run_data( + tensorboard_run="tensorboard_run_value", + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="tensorboard_time_series_id_value" + ) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_run == "tensorboard_run_value" + + assert args[0].time_series_data == [ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="tensorboard_time_series_id_value" + ) + ] + + +def test_write_tensorboard_run_data_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.write_tensorboard_run_data( + tensorboard_service.WriteTensorboardRunDataRequest(), + tensorboard_run="tensorboard_run_value", + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="tensorboard_time_series_id_value" + ) + ], + ) + + +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.write_tensorboard_run_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.WriteTensorboardRunDataResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.WriteTensorboardRunDataResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.write_tensorboard_run_data( + tensorboard_run="tensorboard_run_value", + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="tensorboard_time_series_id_value" + ) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_run == "tensorboard_run_value" + + assert args[0].time_series_data == [ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="tensorboard_time_series_id_value" + ) + ] + + +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.write_tensorboard_run_data( + tensorboard_service.WriteTensorboardRunDataRequest(), + tensorboard_run="tensorboard_run_value", + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="tensorboard_time_series_id_value" + ) + ], + ) + + +def test_export_tensorboard_time_series_data( + transport: str = "grpc", + request_type=tensorboard_service.ExportTensorboardTimeSeriesDataRequest, +): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + next_page_token="next_page_token_value", + ) + + response = client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ExportTensorboardTimeSeriesDataRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ExportTensorboardTimeSeriesDataPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_export_tensorboard_time_series_data_from_dict(): + test_export_tensorboard_time_series_data(request_type=dict) + + +def test_export_tensorboard_time_series_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + client.export_tensorboard_time_series_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ExportTensorboardTimeSeriesDataRequest() + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_async( + transport: str = "grpc_asyncio", + request_type=tensorboard_service.ExportTensorboardTimeSeriesDataRequest, +): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == tensorboard_service.ExportTensorboardTimeSeriesDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ExportTensorboardTimeSeriesDataAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_async_from_dict(): + await test_export_tensorboard_time_series_data_async(request_type=dict) + + +def test_export_tensorboard_time_series_data_field_headers(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest() + request.tensorboard_time_series = "tensorboard_time_series/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + call.return_value = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse() + ) + + client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_time_series=tensorboard_time_series/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_field_headers_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest() + request.tensorboard_time_series = "tensorboard_time_series/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse() + ) + + await client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tensorboard_time_series=tensorboard_time_series/value", + ) in kw["metadata"] + + +def test_export_tensorboard_time_series_data_flattened(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse() + ) + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.export_tensorboard_time_series_data( + tensorboard_time_series="tensorboard_time_series_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_time_series == "tensorboard_time_series_value" + + +def test_export_tensorboard_time_series_data_flattened_error(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.export_tensorboard_time_series_data( + tensorboard_service.ExportTensorboardTimeSeriesDataRequest(), + tensorboard_time_series="tensorboard_time_series_value", + ) + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_flattened_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse() + ) + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.export_tensorboard_time_series_data( + tensorboard_time_series="tensorboard_time_series_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].tensorboard_time_series == "tensorboard_time_series_value" + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_flattened_error_async(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.export_tensorboard_time_series_data( + tensorboard_service.ExportTensorboardTimeSeriesDataRequest(), + tensorboard_time_series="tensorboard_time_series_value", + ) + + +def test_export_tensorboard_time_series_data_pager(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + next_page_token="abc", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[], next_page_token="def", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[tensorboard_data.TimeSeriesDataPoint(),], + next_page_token="ghi", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tensorboard_time_series", ""),) + ), + ) + pager = client.export_tensorboard_time_series_data(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, tensorboard_data.TimeSeriesDataPoint) for i in results) + + +def test_export_tensorboard_time_series_data_pages(): + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + next_page_token="abc", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[], next_page_token="def", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[tensorboard_data.TimeSeriesDataPoint(),], + next_page_token="ghi", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + ), + RuntimeError, + ) + pages = list(client.export_tensorboard_time_series_data(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_async_pager(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + next_page_token="abc", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[], next_page_token="def", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[tensorboard_data.TimeSeriesDataPoint(),], + next_page_token="ghi", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + ), + RuntimeError, + ) + async_pager = await client.export_tensorboard_time_series_data(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all( + isinstance(i, tensorboard_data.TimeSeriesDataPoint) for i in responses + ) + + +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_async_pages(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_tensorboard_time_series_data), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + next_page_token="abc", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[], next_page_token="def", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[tensorboard_data.TimeSeriesDataPoint(),], + next_page_token="ghi", + ), + tensorboard_service.ExportTensorboardTimeSeriesDataResponse( + time_series_data_points=[ + tensorboard_data.TimeSeriesDataPoint(), + tensorboard_data.TimeSeriesDataPoint(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in ( + await client.export_tensorboard_time_series_data(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TensorboardServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TensorboardServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = TensorboardServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.TensorboardServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.TensorboardServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TensorboardServiceGrpcTransport, + transports.TensorboardServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = TensorboardServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.TensorboardServiceGrpcTransport,) + + +def test_tensorboard_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.TensorboardServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_tensorboard_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports.TensorboardServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.TensorboardServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_tensorboard", + "get_tensorboard", + "update_tensorboard", + "list_tensorboards", + "delete_tensorboard", + "create_tensorboard_experiment", + "get_tensorboard_experiment", + "update_tensorboard_experiment", + "list_tensorboard_experiments", + "delete_tensorboard_experiment", + "create_tensorboard_run", + "get_tensorboard_run", + "update_tensorboard_run", + "list_tensorboard_runs", + "delete_tensorboard_run", + "create_tensorboard_time_series", + "get_tensorboard_time_series", + "update_tensorboard_time_series", + "list_tensorboard_time_series", + "delete_tensorboard_time_series", + "read_tensorboard_time_series_data", + "read_tensorboard_blob_data", + "write_tensorboard_run_data", + "export_tensorboard_time_series_data", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_tensorboard_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports.TensorboardServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.TensorboardServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_tensorboard_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports.TensorboardServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.TensorboardServiceTransport() + adc.assert_called_once() + + +def test_tensorboard_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + TensorboardServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_tensorboard_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.TensorboardServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TensorboardServiceGrpcTransport, + transports.TensorboardServiceGrpcAsyncIOTransport, + ], +) +def test_tensorboard_service_grpc_transport_client_cert_source_for_mtls( + transport_class, +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_tensorboard_service_host_no_port(): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_tensorboard_service_host_with_port(): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_tensorboard_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TensorboardServiceGrpcTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_tensorboard_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TensorboardServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.TensorboardServiceGrpcTransport, + transports.TensorboardServiceGrpcAsyncIOTransport, + ], +) +def test_tensorboard_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.TensorboardServiceGrpcTransport, + transports.TensorboardServiceGrpcAsyncIOTransport, + ], +) +def test_tensorboard_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_tensorboard_service_grpc_lro_client(): + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_tensorboard_service_grpc_lro_async_client(): + client = TensorboardServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_tensorboard_path(): + project = "squid" + location = "clam" + tensorboard = "whelk" + + expected = "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( + project=project, location=location, tensorboard=tensorboard, + ) + actual = TensorboardServiceClient.tensorboard_path(project, location, tensorboard) + assert expected == actual + + +def test_parse_tensorboard_path(): + expected = { + "project": "octopus", + "location": "oyster", + "tensorboard": "nudibranch", + } + path = TensorboardServiceClient.tensorboard_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_tensorboard_path(path) + assert expected == actual + + +def test_tensorboard_experiment_path(): + project = "cuttlefish" + location = "mussel" + tensorboard = "winkle" + experiment = "nautilus" + + expected = "projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}".format( + project=project, + location=location, + tensorboard=tensorboard, + experiment=experiment, + ) + actual = TensorboardServiceClient.tensorboard_experiment_path( + project, location, tensorboard, experiment + ) + assert expected == actual + + +def test_parse_tensorboard_experiment_path(): + expected = { + "project": "scallop", + "location": "abalone", + "tensorboard": "squid", + "experiment": "clam", + } + path = TensorboardServiceClient.tensorboard_experiment_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_tensorboard_experiment_path(path) + assert expected == actual + + +def test_tensorboard_run_path(): + project = "whelk" + location = "octopus" + tensorboard = "oyster" + experiment = "nudibranch" + run = "cuttlefish" + + expected = "projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}".format( + project=project, + location=location, + tensorboard=tensorboard, + experiment=experiment, + run=run, + ) + actual = TensorboardServiceClient.tensorboard_run_path( + project, location, tensorboard, experiment, run + ) + assert expected == actual + + +def test_parse_tensorboard_run_path(): + expected = { + "project": "mussel", + "location": "winkle", + "tensorboard": "nautilus", + "experiment": "scallop", + "run": "abalone", + } + path = TensorboardServiceClient.tensorboard_run_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_tensorboard_run_path(path) + assert expected == actual + + +def test_tensorboard_time_series_path(): + project = "squid" + location = "clam" + tensorboard = "whelk" + experiment = "octopus" + run = "oyster" + time_series = "nudibranch" + + expected = "projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}/timeSeries/{time_series}".format( + project=project, + location=location, + tensorboard=tensorboard, + experiment=experiment, + run=run, + time_series=time_series, + ) + actual = TensorboardServiceClient.tensorboard_time_series_path( + project, location, tensorboard, experiment, run, time_series + ) + assert expected == actual + + +def test_parse_tensorboard_time_series_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + "tensorboard": "winkle", + "experiment": "nautilus", + "run": "scallop", + "time_series": "abalone", + } + path = TensorboardServiceClient.tensorboard_time_series_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_tensorboard_time_series_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = TensorboardServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = TensorboardServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = TensorboardServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = TensorboardServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = TensorboardServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = TensorboardServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = TensorboardServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = TensorboardServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = TensorboardServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = TensorboardServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = TensorboardServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.TensorboardServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = TensorboardServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.TensorboardServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = TensorboardServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) From 5accf1d26c25dbc34f329a33797fcc5fa2a1f315 Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Tue, 4 May 2021 15:17:32 -0700 Subject: [PATCH 18/36] feat: remove MB type, enforce metrics value type and add experiment description support (#361) --- google/cloud/aiplatform/initializer.py | 14 +++- google/cloud/aiplatform/metadata/constants.py | 2 +- google/cloud/aiplatform/metadata/metadata.py | 37 +++++++++-- google/cloud/aiplatform/metadata/resource.py | 10 +++ tests/unit/aiplatform/test_initializer.py | 20 +++++- tests/unit/aiplatform/test_metadata.py | 65 +++++++++++++++++++ 6 files changed, 138 insertions(+), 10 deletions(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index eecbac61c7..f6ed877de7 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -56,6 +56,7 @@ def init( project: Optional[str] = None, location: Optional[str] = None, experiment: Optional[str] = None, + experiment_description: Optional[str] = None, staging_bucket: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, encryption_spec_key_name: Optional[str] = None, @@ -65,8 +66,9 @@ def init( Args: project (str): The default project to use when making API calls. location (str): The default location to use when making API calls. If not - set defaults to us-central-1 - experiment (str): The experiment name + set defaults to us-central-1. + experiment (str): The experiment name. + experiment_description (str): The description of the experiment. staging_bucket (str): The default staging bucket to use to stage artifacts when making API calls. In the form gs://... credentials (google.auth.credentials.Credentials): The default custom @@ -96,7 +98,13 @@ def init( utils.validate_region(location) self._location = location if experiment: - metadata.metadata_service.set_experiment(experiment) + metadata.metadata_service.set_experiment( + experiment=experiment, description=experiment_description + ) + if experiment_description and experiment is None: + raise ValueError( + "Experiment name needs to be set in `init` in order to add experiment descriptions." + ) if staging_bucket: self._staging_bucket = staging_bucket if credentials: diff --git a/google/cloud/aiplatform/metadata/constants.py b/google/cloud/aiplatform/metadata/constants.py index 7db87cc222..62e7d6e075 100644 --- a/google/cloud/aiplatform/metadata/constants.py +++ b/google/cloud/aiplatform/metadata/constants.py @@ -31,4 +31,4 @@ # The EXPERIMENT_METADATA is needed until we support context deletion in backend service. # TODO: delete EXPERIMENT_METADATA once backend supports context deletion. -EXPERIMENT_METADATA = {"experiment_deleted": False, "experiment_type": "MB"} +EXPERIMENT_METADATA = {"experiment_deleted": False} diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index e31350d466..919eff8619 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -52,20 +52,21 @@ def run_name(self) -> Optional[str]: return self._run.display_name return None - def set_experiment(self, experiment: str): + def set_experiment(self, experiment: str, description: Optional[str] = None): """Setup a experiment to current session. Args: experiment (str): Required. Name of the experiment to assign current session with. - Raises: - ValueError if a context with the same name as the experiment is create but with a different schema. + description (str): + Optional. Description of an experiment. """ _MetadataStore.get_or_create() context = _Context.get_or_create( resource_id=experiment, display_name=experiment, + description=description, schema_title=constants.SYSTEM_EXPERIMENT, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], metadata=constants.EXPERIMENT_METADATA, @@ -75,6 +76,10 @@ def set_experiment(self, experiment: str): f"Experiment name {experiment} has been used to create other type of resources " f"({context.schema_title}) in this MetadataStore, please choose a different experiment name." ) + + if description and context.description != description: + context.update(metadata=context.metadata, description=description) + self._experiment = context def start_run(self, run: str): @@ -145,15 +150,19 @@ def log_params(self, params: Dict[str, Union[float, int, str]]): ) execution.update(metadata=params) - def log_metrics(self, metrics: Dict[str, Union[str, float, int]]): + def log_metrics(self, metrics: Dict[str, Union[float, int]]): """Log single or multiple Metrics with specified key and value pairs. Args: metrics (Dict): - Required. Metrics key/value pairs. + Required. Metrics key/value pairs. Only flot and int are supported format for value. + Raises: + TypeError if value contains unsupported types. + ValueError if Experiment or Run is not set. """ self._validate_experiment_and_run(method_name="log_metrics") + self._validate_metrics_value_type(metrics) # query the latest metrics artifact resource before logging. artifact = _Artifact.get_or_create( resource_id=self._metrics.name, @@ -248,6 +257,24 @@ def _validate_experiment_and_run(self, method_name: str): f"No run set. Make sure to call aiplatform.start_run('my-run') before trying to {method_name}. " ) + @staticmethod + def _validate_metrics_value_type(metrics: Dict[str, Union[float, int]]): + """Verify that metrics value are with supported types. + + Args: + metrics (Dict): + Required. Metrics key/value pairs. Only flot and int are supported format for value. + Raises: + TypeError if value contains unsupported types. + """ + + for key, value in metrics.items(): + if isinstance(value, int) or isinstance(value, float): + continue + raise TypeError( + f"metrics contain unsupported value types. key: {key}; value: {value}; type: {type(value)}" + ) + @staticmethod def _get_experiment_or_pipeline_resource_name( name: str, source: str, expected_schema: str diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py index 03266bafe3..11f03b7af1 100644 --- a/google/cloud/aiplatform/metadata/resource.py +++ b/google/cloud/aiplatform/metadata/resource.py @@ -104,6 +104,10 @@ def metadata(self) -> Dict: def schema_title(self) -> str: return self._gca_resource.schema_title + @property + def description(self) -> str: + return self._gca_resource.description + @classmethod def get_or_create( cls, @@ -182,6 +186,7 @@ def get_or_create( def update( self, metadata: Dict, + description: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ): """Updates an existing Metadata resource with new metadata. @@ -189,6 +194,8 @@ def update( Args: metadata (Dict): Required. metadata contains the updated metadata information. + description (str): + Optional. Description describes the resource to be updated. credentials (auth_credentials.Credentials): Custom credentials to use to update this resource. Overrides credentials set in aiplatform.init. @@ -200,6 +207,9 @@ def update( gca_resource.metadata.update(metadata) else: gca_resource.metadata = metadata + if description: + gca_resource.description = description + api_client = self._instantiate_client(credentials=credentials) update_gca_resource = self._update_resource( diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index fc0c33dbca..7e65c99b4c 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -39,6 +39,7 @@ _TEST_LOCATION_2 = "europe-west4" _TEST_INVALID_LOCATION = "test-invalid-location" _TEST_EXPERIMENT = "test-experiment" +_TEST_DESCRIPTION = "test-description" _TEST_STAGING_BUCKET = "test-bucket" @@ -74,7 +75,24 @@ def test_init_location_with_invalid_location_raises(self): @patch.object(metadata_service, "set_experiment") def test_init_experiment_sets_experiment(self, set_experiment_mock): initializer.global_config.init(experiment=_TEST_EXPERIMENT) - set_experiment_mock.assert_called_once_with(_TEST_EXPERIMENT) + set_experiment_mock.assert_called_once_with( + experiment=_TEST_EXPERIMENT, description=None + ) + + @patch.object(metadata_service, "set_experiment") + def test_init_experiment_sets_experiment_with_description( + self, set_experiment_mock + ): + initializer.global_config.init( + experiment=_TEST_EXPERIMENT, experiment_description=_TEST_DESCRIPTION + ) + set_experiment_mock.assert_called_once_with( + experiment=_TEST_EXPERIMENT, description=_TEST_DESCRIPTION + ) + + def test_init_experiment_description_fail_without_experiment(self): + with pytest.raises(ValueError): + initializer.global_config.init(experiment_description=_TEST_DESCRIPTION) def test_init_staging_bucket_sets_staging_bucket(self): initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index bdb1c7bc39..26e297426b 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -49,6 +49,8 @@ f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" ) _TEST_EXPERIMENT = "test-experiment" +_TEST_EXPERIMENT_DESCRIPTION = "test-experiment-description" +_TEST_OTHER_EXPERIMENT_DESCRIPTION = "test-other-experiment-description" _TEST_PIPELINE = _TEST_EXPERIMENT _TEST_RUN = "run-1" _TEST_OTHER_RUN = "run-2" @@ -110,6 +112,7 @@ def get_context_mock(): get_context_mock.return_value = GapicContext( name=_TEST_CONTEXT_NAME, display_name=_TEST_EXPERIMENT, + description=_TEST_EXPERIMENT_DESCRIPTION, schema_title=constants.SYSTEM_EXPERIMENT, schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], metadata=constants.EXPERIMENT_METADATA, @@ -156,6 +159,20 @@ def get_context_not_found_mock(): yield get_context_not_found_mock +@pytest.fixture +def update_context_mock(): + with patch.object(MetadataServiceClient, "update_context") as update_context_mock: + update_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_OTHER_EXPERIMENT_DESCRIPTION, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + yield update_context_mock + + @pytest.fixture def add_context_artifacts_and_executions_mock(): with patch.object( @@ -347,6 +364,40 @@ def test_init_experiment_with_existing_metadataStore_and_context( get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + def test_init_experiment_with_existing_description( + self, get_metadata_store_mock, get_context_mock + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + experiment_description=_TEST_EXPERIMENT_DESCRIPTION, + ) + + get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + def test_init_experiment_without_existing_description(self, update_context_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + experiment_description=_TEST_OTHER_EXPERIMENT_DESCRIPTION, + ) + + experiment_context = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_OTHER_EXPERIMENT_DESCRIPTION, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + + update_context_mock.assert_called_once_with(context=experiment_context) + @pytest.mark.usefixtures("get_metadata_store_mock") @pytest.mark.usefixtures("get_context_wrong_schema_mock") def test_init_experiment_wrong_schema(self): @@ -477,6 +528,20 @@ def test_log_metrics( update_artifact_mock.assert_called_once_with(artifact=updated_artifact) + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_log_metrics_string_value_raise_error(self): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + with pytest.raises(TypeError): + aiplatform.log_metrics({"test": "string"}) + # TODO: remove skip once koroko test would install extra required packages. @pytest.mark.skip( reason="Temporarily skip this test as extra required package are not installed in current setup" From 20320979436475bd276f0514e0e44400e0040327 Mon Sep 17 00:00:00 2001 From: thehardikv <78449654+thehardikv@users.noreply.github.com> Date: Wed, 5 May 2021 16:41:45 -0700 Subject: [PATCH 19/36] feat: Added AutoMLForecastingTrainingJob and tests (#237) --- google/cloud/aiplatform/__init__.py | 3 + google/cloud/aiplatform/datasets/__init__.py | 2 + .../cloud/aiplatform/datasets/_datasources.py | 5 + .../datasets/time_series_dataset.py | 134 +++++ google/cloud/aiplatform/schema.py | 2 + google/cloud/aiplatform/training_jobs.py | 480 ++++++++++++++++- .../test_automl_forecasting_training_jobs.py | 488 ++++++++++++++++++ 7 files changed, 1100 insertions(+), 14 deletions(-) create mode 100644 google/cloud/aiplatform/datasets/time_series_dataset.py create mode 100644 tests/unit/aiplatform/test_automl_forecasting_training_jobs.py diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 58eb824454..cc7e119603 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -33,6 +33,7 @@ CustomContainerTrainingJob, CustomPythonPackageTrainingJob, AutoMLTabularTrainingJob, + AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTextTrainingJob, AutoMLVideoTrainingJob, @@ -52,6 +53,7 @@ "init", "AutoMLImageTrainingJob", "AutoMLTabularTrainingJob", + "AutoMLForecastingTrainingJob", "AutoMLTextTrainingJob", "AutoMLVideoTrainingJob", "BatchPredictionJob", @@ -63,5 +65,6 @@ "Model", "TabularDataset", "TextDataset", + "TimeSeriesDataset", "VideoDataset", ) diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py index 57e2bad45d..b297530955 100644 --- a/google/cloud/aiplatform/datasets/__init__.py +++ b/google/cloud/aiplatform/datasets/__init__.py @@ -17,6 +17,7 @@ from google.cloud.aiplatform.datasets.dataset import _Dataset from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset +from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset from google.cloud.aiplatform.datasets.image_dataset import ImageDataset from google.cloud.aiplatform.datasets.text_dataset import TextDataset from google.cloud.aiplatform.datasets.video_dataset import VideoDataset @@ -25,6 +26,7 @@ __all__ = ( "_Dataset", "TabularDataset", + "TimeSeriesDataset", "ImageDataset", "TextDataset", "VideoDataset", diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py index eefd1b04fd..1221429258 100644 --- a/google/cloud/aiplatform/datasets/_datasources.py +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -224,6 +224,11 @@ def create_datasource( raise ValueError("tabular dataset does not support data import.") return TabularDatasource(gcs_source, bq_source) + if metadata_schema_uri == schema.dataset.metadata.time_series: + if import_schema_uri: + raise ValueError("time series dataset does not support data import.") + return TabularDatasource(gcs_source, bq_source) + if not import_schema_uri and not gcs_source: return NonTabularDatasource() elif import_schema_uri and gcs_source: diff --git a/google/cloud/aiplatform/datasets/time_series_dataset.py b/google/cloud/aiplatform/datasets/time_series_dataset.py new file mode 100644 index 0000000000..92d8e60c37 --- /dev/null +++ b/google/cloud/aiplatform/datasets/time_series_dataset.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class TimeSeriesDataset(datasets._Dataset): + """Managed time series dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.time_series, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "TimeSeriesDataset": + """Creates a new tabular dataset. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + time_series_dataset (TimeSeriesDataset): + Instantiated representation of the managed time series dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.time_series + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + gcs_source=gcs_source, + bq_source=bq_source, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + def import_data(self): + raise NotImplementedError( + f"{self.__class__.__name__} class does not support 'import_data'" + ) diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py index 04d2f026a1..6b2a3d7d66 100644 --- a/google/cloud/aiplatform/schema.py +++ b/google/cloud/aiplatform/schema.py @@ -22,6 +22,7 @@ class training_job: class definition: custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml" automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml" + automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml" automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml" automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml" automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml" @@ -37,6 +38,7 @@ class metadata: tabular = ( "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml" ) + time_series = "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml" image = "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml" text = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml" video = "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml" diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 220a34637e..572e3ad0ae 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -279,7 +279,7 @@ def _create_input_data_config( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. gcs_destination_uri_prefix (str): Optional. The Google Cloud Storage location. @@ -320,12 +320,12 @@ def _create_input_data_config( # Create predefined split spec predefined_split = None if predefined_split_column_name: - if ( - dataset._gca_resource.metadata_schema_uri - != schema.dataset.metadata.tabular + if dataset._gca_resource.metadata_schema_uri not in ( + schema.dataset.metadata.tabular, + schema.dataset.metadata.time_series, ): raise ValueError( - "A pre-defined split may only be used with a tabular Dataset" + "A pre-defined split may only be used with a tabular or time series Dataset" ) predefined_split = gca_training_pipeline.PredefinedSplit( @@ -438,7 +438,7 @@ def _run_job( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. model (~.model.Model): Optional. Describes the Model that may be uploaded (via [ModelService.UploadMode][]) by this TrainingPipeline. The @@ -1904,7 +1904,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -2026,7 +2026,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -2421,7 +2421,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -2537,7 +2537,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -2759,7 +2759,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. weight_column (str): Optional. Name of the column that should be used as the weight column. Higher values in this column give more importance to the row @@ -2874,7 +2874,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. weight_column (str): Optional. Name of the column that should be used as the weight column. Higher values in this column give more importance to the row @@ -2960,6 +2960,458 @@ def _model_upload_fail_string(self) -> str: ) +class AutoMLForecastingTrainingJob(_TrainingJob): + _supported_training_schemas = (schema.training_job.definition.automl_forecasting,) + + def __init__( + self, + display_name: str, + optimization_objective: Optional[str] = None, + column_transformations: Optional[Union[Dict, List[Dict]]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Constructs a AutoML Forecasting Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + optimization_objective (str): + Optional. Objective function the model is to be optimized towards. + The training process creates a Model that optimizes the value of the objective + function over the validation set. The supported optimization objectives: + "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE). + "minimize-mae" - Minimize mean-absolute error (MAE). + "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE). + "minimize-rmspe" - Minimize root-mean-squared percentage error (RMSPE). + "minimize-wape-mae" - Minimize the combination of weighted absolute percentage error (WAPE) + and mean-absolute-error (MAE). + "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles. + (Set this objective to build quantile forecasts.) + column_transformations (Optional[Union[Dict, List[Dict]]]): + Optional. Transformations to apply to the input columns (i.e. columns other + than the targetColumn). Each transformation may produce multiple + result values from the column's value, and all are used for training. + When creating transformation for BigQuery Struct column, the column + should be flattened using "." as the delimiter. + If an input column has no transformations on it, such a column is + ignored by the training, except for the targetColumn, which should have + no transformations defined on. + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + ) + self._column_transformations = column_transformations + self._optimization_objective = optimization_objective + + def run( + self, + dataset: datasets.TimeSeriesDataset, + target_column: str, + time_column: str, + time_series_identifier_column: str, + unavailable_at_forecast_columns: List[str], + available_at_forecast_columns: List[str], + forecast_horizon: int, + data_granularity_unit: str, + data_granularity_count: int, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + time_series_attribute_columns: Optional[List[str]] = None, + context_window: Optional[int] = None, + export_evaluated_data_items: bool = False, + export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_override_destination: bool = False, + quantiles: Optional[List[float]] = None, + validation_options: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + The training data splits are set by default: Roughly 80% will be used for training, + 10% for validation, and 10% for test. + + Args: + dataset (datasets.Dataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For time series Datasets, all their data is exported to + training, to pick and choose from. + target_column (str): + Required. Name of the column that the Model is to predict values for. + time_column (str): + Required. Name of the column that identifies time order in the time series. + time_series_identifier_column (str): + Required. Name of the column that identifies the time series. + unavailable_at_forecast_columns (List[str]): + Required. Column names of columns that are unavailable at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is unknown before the forecast + (e.g. population of a city in a given year, or weather on a given day). + available_at_forecast_columns (List[str]): + Required. Column names of columns that are available at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is known at forecast. + forecast_horizon: (int): + Required. The amount of time into the future for which forecasted values for the target are + returned. Expressed in number of units defined by the [data_granularity_unit] and + [data_granularity_count] field. Inclusive. + data_granularity_unit (str): + Required. The data granularity unit. Accepted values are ``minute``, + ``hour``, ``day``, ``week``, ``month``, ``year``. + data_granularity_count (int): + Required. The number of data granularity units between data points in the training + data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other + values of [data_granularity_unit], must be 1. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``TRAIN``, + ``VALIDATE``, ``TEST``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + time_series_attribute_columns (List[str]): + Optional. Column names that should be used as attribute columns. + Each column is constant within a time series. + context_window (int): + Optional. The amount of time into the past training and prediction data is used for + model training and prediction respectively. Expressed in number of units defined by the + [data_granularity_unit] and [data_granularity_count] fields. When not provided uses the + default value of 0 which means the model sets each series context window to be 0 (also + known as "cold start"). Inclusive. + export_evaluated_data_items (bool): + Whether to export the test set predictions to a BigQuery table. + If False, then the export is not performed. + export_evaluated_data_items_bigquery_destination_uri (string): + Optional. URI of desired destination BigQuery table for exported test set predictions. + + Expected format: + ``bq://::`` + + If not specified, then results are exported to the following auto-created BigQuery + table: + ``:export_evaluated_examples__.evaluated_examples`` + + Applies only if [export_evaluated_data_items] is True. + export_evaluated_data_items_override_destination (bool): + Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri], + if the table exists, for exported test set predictions. If False, and the + table exists, then the training job will fail. + + Applies only if [export_evaluated_data_items] is True and + [export_evaluated_data_items_bigquery_destination_uri] is specified. + quantiles (List[float]): + Quantiles to use for the `minizmize-quantile-loss` + [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in + this case. + + Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive. + Each quantile must be unique. + validation_options (str): + Validation options for the data validation component. The available options are: + "fail-pipeline" - (default), will validate against the validation and fail the pipeline + if it fails. + "ignore-validation" - ignore the results of the validation and continue the pipeline + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. The final + cost will be attempted to be close to the budget, though may end up + being (even) noticeably smaller - at the backend's discretion. This + especially may happen when further model training ceases to provide + any improvements. + If the budget is set to a value known to be insufficient to train a + Model for the given training set, the training won't be attempted and + will error. + The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError if Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError( + "AutoML Forecasting Training is already scheduled to run." + ) + + if self._has_run: + raise RuntimeError("AutoML Forecasting Training has already run.") + + return self._run( + dataset=dataset, + target_column=target_column, + time_column=time_column, + time_series_identifier_column=time_series_identifier_column, + unavailable_at_forecast_columns=unavailable_at_forecast_columns, + available_at_forecast_columns=available_at_forecast_columns, + forecast_horizon=forecast_horizon, + data_granularity_unit=data_granularity_unit, + data_granularity_count=data_granularity_count, + predefined_split_column_name=predefined_split_column_name, + weight_column=weight_column, + time_series_attribute_columns=time_series_attribute_columns, + context_window=context_window, + budget_milli_node_hours=budget_milli_node_hours, + export_evaluated_data_items=export_evaluated_data_items, + export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri, + export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination, + quantiles=quantiles, + validation_options=validation_options, + model_display_name=model_display_name, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.TimeSeriesDataset, + target_column: str, + time_column: str, + time_series_identifier_column: str, + unavailable_at_forecast_columns: List[str], + available_at_forecast_columns: List[str], + forecast_horizon: int, + data_granularity_unit: str, + data_granularity_count: int, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + time_series_attribute_columns: Optional[List[str]] = None, + context_window: Optional[int] = None, + export_evaluated_data_items: bool = False, + export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_override_destination: bool = False, + quantiles: Optional[List[float]] = None, + validation_options: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + The training data splits are set by default: Roughly 80% will be used for training, + 10% for validation, and 10% for test. + + Args: + dataset (datasets.Dataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For time series Datasets, all their data is exported to + training, to pick and choose from. + target_column (str): + Required. Name of the column that the Model is to predict values for. + time_column (str): + Required. Name of the column that identifies time order in the time series. + time_series_identifier_column (str): + Required. Name of the column that identifies the time series. + unavailable_at_forecast_columns (List[str]): + Required. Column names of columns that are unavailable at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is unknown before the forecast + (e.g. population of a city in a given year, or weather on a given day). + available_at_forecast_columns (List[str]): + Required. Column names of columns that are available at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is known at forecast. + forecast_horizon: (int): + Required. The amount of time into the future for which forecasted values for the target are + returned. Expressed in number of units defined by the [data_granularity_unit] and + [data_granularity_count] field. Inclusive. + data_granularity_unit (str): + Required. The data granularity unit. Accepted values are ``minute``, + ``hour``, ``day``, ``week``, ``month``, ``year``. + data_granularity_count (int): + Required. The number of data granularity units between data points in the training + data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other + values of [data_granularity_unit], must be 1. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``TRAIN``, + ``VALIDATE``, ``TEST``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + time_series_attribute_columns (List[str]): + Optional. Column names that should be used as attribute columns. + Each column is constant within a time series. + context_window (int): + Optional. The number of periods offset into the past to restrict past sequence, where each + period is one unit of granularity as defined by [period]. When not provided uses the + default value of 0 which means the model sets each series historical window to be 0 (also + known as "cold start"). Inclusive. + export_evaluated_data_items (bool): + Whether to export the test set predictions to a BigQuery table. + If False, then the export is not performed. + export_evaluated_data_items_bigquery_destination_uri (string): + Optional. URI of desired destination BigQuery table for exported test set predictions. + + Expected format: + ``bq://::
`` + + If not specified, then results are exported to the following auto-created BigQuery + table: + ``:export_evaluated_examples__.evaluated_examples`` + + Applies only if [export_evaluated_data_items] is True. + export_evaluated_data_items_override_destination (bool): + Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri], + if the table exists, for exported test set predictions. If False, and the + table exists, then the training job will fail. + + Applies only if [export_evaluated_data_items] is True and + [export_evaluated_data_items_bigquery_destination_uri] is specified. + quantiles (List[float]): + Quantiles to use for the `minizmize-quantile-loss` + [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in + this case. + + Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive. + Each quantile must be unique. + validation_options (str): + Validation options for the data validation component. The available options are: + "fail-pipeline" - (default), will validate against the validation and fail the pipeline + if it fails. + "ignore-validation" - ignore the results of the validation and continue the pipeline + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. The final + cost will be attempted to be close to the budget, though may end up + being (even) noticeably smaller - at the backend's discretion. This + especially may happen when further model training ceases to provide + any improvements. + If the budget is set to a value known to be insufficient to train a + Model for the given training set, the training won't be attempted and + will error. + The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + training_task_definition = schema.training_job.definition.automl_forecasting + + training_task_inputs_dict = { + # required inputs + "targetColumn": target_column, + "timeColumn": time_column, + "timeSeriesIdentifierColumn": time_series_identifier_column, + "timeSeriesAttributeColumns": time_series_attribute_columns, + "unavailableAtForecastColumns": unavailable_at_forecast_columns, + "availableAtForecastColumns": available_at_forecast_columns, + "forecastHorizon": forecast_horizon, + "dataGranularity": { + "unit": data_granularity_unit, + "quantity": data_granularity_count, + }, + "transformations": self._column_transformations, + "trainBudgetMilliNodeHours": budget_milli_node_hours, + # optional inputs + "weightColumn": weight_column, + "contextWindow": context_window, + "quantiles": quantiles, + "validationOptions": validation_options, + "optimizationObjective": self._optimization_objective, + } + + final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri + if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith( + "bq://" + ): + final_export_eval_bq_uri = f"bq://{final_export_eval_bq_uri}" + + if export_evaluated_data_items: + training_task_inputs_dict["exportEvaluatedDataItemsConfig"] = { + "destinationBigqueryUri": final_export_eval_bq_uri, + "overrideExistingTable": export_evaluated_data_items_override_destination, + } + + if model_display_name is None: + model_display_name = self._display_name + + model = gca_model.Model(display_name=model_display_name) + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=0.8, + validation_fraction_split=0.1, + test_fraction_split=0.1, + predefined_split_column_name=predefined_split_column_name, + model=model, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"Training Pipeline {self.resource_name} is not configured to upload a " + "Model." + ) + + class AutoMLImageTrainingJob(_TrainingJob): _supported_training_schemas = ( schema.training_job.definition.automl_image_classification, @@ -3686,7 +4138,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -3784,7 +4236,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py new file mode 100644 index 0000000000..5d89360566 --- /dev/null +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -0,0 +1,488 @@ +import importlib +import pytest +from unittest import mock + +from google.cloud import aiplatform +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform.training_jobs import AutoMLForecastingTrainingJob + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +_TEST_BUCKET_NAME = "test-bucket" +_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" +_TEST_GCS_PATH = f"{_TEST_BUCKET_NAME}/{_TEST_GCS_PATH_WITHOUT_BUCKET}" +_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" +_TEST_PROJECT = "test-project" + +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" +_TEST_METADATA_SCHEMA_URI_TIMESERIES = schema.dataset.metadata.time_series +_TEST_METADATA_SCHEMA_URI_NONTIMESERIES = schema.dataset.metadata.image + +_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [ + {"auto": {"column_name": "time"}}, + {"auto": {"column_name": "time_series_identifier"}}, + {"auto": {"column_name": "target"}}, + {"auto": {"column_name": "weight"}}, +] +_TEST_TRAINING_TARGET_COLUMN = "target" +_TEST_TRAINING_TIME_COLUMN = "time" +_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN = "time_series_identifier" +_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS = [] +_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS = [] +_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS = [] +_TEST_TRAINING_FORECAST_HORIZON = 10 +_TEST_TRAINING_DATA_GRANULARITY_UNIT = "day" +_TEST_TRAINING_DATA_GRANULARITY_COUNT = 1 +_TEST_TRAINING_CONTEXT_WINDOW = None +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS = True +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI = ( + "bq://path.to.table" +) +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION = False +_TEST_TRAINING_QUANTILES = None +_TEST_TRAINING_VALIDATION_OPTIONS = None +_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS = 1000 +_TEST_TRAINING_WEIGHT_COLUMN = "weight" +_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-rmse" +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + { + # required inputs + "targetColumn": _TEST_TRAINING_TARGET_COLUMN, + "timeColumn": _TEST_TRAINING_TIME_COLUMN, + "timeSeriesIdentifierColumn": _TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + "timeSeriesAttributeColumns": _TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + "unavailableAtForecastColumns": _TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + "availableAtForecastColumns": _TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + "forecastHorizon": _TEST_TRAINING_FORECAST_HORIZON, + "dataGranularity": { + "unit": _TEST_TRAINING_DATA_GRANULARITY_UNIT, + "quantity": _TEST_TRAINING_DATA_GRANULARITY_COUNT, + }, + "transformations": _TEST_TRAINING_COLUMN_TRANSFORMATIONS, + "trainBudgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + # optional inputs + "weightColumn": _TEST_TRAINING_WEIGHT_COLUMN, + "contextWindow": _TEST_TRAINING_CONTEXT_WINDOW, + "exportEvaluatedDataItemsConfig": { + "destinationBigqueryUri": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + "overrideExistingTable": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + }, + "quantiles": _TEST_TRAINING_QUANTILES, + "validationOptions": _TEST_TRAINING_VALIDATION_OPTIONS, + "optimizationObjective": _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + }, + struct_pb2.Value(), +) + +_TEST_DATASET_NAME = "test-dataset-name" + +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_TRAINING_FRACTION_SPLIT = 0.8 +_TEST_VALIDATION_FRACTION_SPLIT = 0.1 +_TEST_TEST_FRACTION_SPLIT = 0.1 +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" + +_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" + +_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" + +_TEST_PIPELINE_RESOURCE_NAME = ( + "projects/my-project/locations/us-central1/trainingPipeline/12345" +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_time_series(): + ds = mock.MagicMock(datasets.TimeSeriesDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TIMESERIES, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_dataset_nontimeseries(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTIMESERIES, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +class TestAutoMLForecastingTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_dataset_time_series, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_dataset_time_series, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state From 8cf8333af5db5fc8d3e2fc342351e15aa43d1d9f Mon Sep 17 00:00:00 2001 From: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> Date: Wed, 5 May 2021 19:51:08 -0400 Subject: [PATCH 20/36] chore: add extras dependencies and install for testing (#369) --- noxfile.py | 2 +- setup.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 7b28a76f53..c5a1b43f35 100644 --- a/noxfile.py +++ b/noxfile.py @@ -95,7 +95,7 @@ def default(session): session.install("mock", "pytest", "pytest-cov", "-c", constraints_path) - session.install("-e", ".", "-c", constraints_path) + session.install("-e", ".[testing]", "-c", constraints_path) # Run py.test against the unit tests. session.run( diff --git a/setup.py b/setup.py index 84bb3c75d9..8b13352f84 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,12 @@ with io.open(readme_filename, encoding="utf-8") as readme_file: readme = readme_file.read() +tensorboard_extra_require = ["tensorflow-cpu >= 2.3.0, <=2.5.0rc"] +metadata_extra_require = ["pandas >= 1.0.0"] +full_extra_require = tensorboard_extra_require + metadata_extra_require +testing_extra_require = full_extra_require + ["grpcio-testing >= 1.37.1"] + + setuptools.setup( name=name, version=version, @@ -48,7 +54,12 @@ "google-cloud-storage >= 1.26.0, < 2.0.0dev", "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", ), - extras_require={"full": ["pandas>=1.0.0"]}, + extras_require={ + "full": full_extra_require, + "metadata": metadata_extra_require, + "tensorboard": tensorboard_extra_require, + "testing": testing_extra_require, + }, python_requires=">=3.6", scripts=[], classifiers=[ From b13fb2c7e7c0bd19332cdde89d87891bb9281342 Mon Sep 17 00:00:00 2001 From: jialuzh <35091833+jialuzh@users.noreply.github.com> Date: Fri, 7 May 2021 12:35:02 -0700 Subject: [PATCH 21/36] fix: enable metadata dataframe related tests (#372) --- tests/unit/aiplatform/test_metadata.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index 26e297426b..9a930dd3f5 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -542,10 +542,6 @@ def test_log_metrics_string_value_raise_error(self): with pytest.raises(TypeError): aiplatform.log_metrics({"test": "string"}) - # TODO: remove skip once koroko test would install extra required packages. - @pytest.mark.skip( - reason="Temporarily skip this test as extra required package are not installed in current setup" - ) @pytest.mark.usefixtures("get_context_mock") def test_get_experiment_df( self, list_executions_mock, query_execution_inputs_and_outputs_mock @@ -605,9 +601,6 @@ def test_get_experiment_df_wrong_schema(self): with pytest.raises(ValueError): aiplatform.get_experiment_df(_TEST_EXPERIMENT) - @pytest.mark.skip( - reason="Temporarily skip this test as extra required package are not installed in current setup" - ) @pytest.mark.usefixtures("get_pipeline_context_mock") def test_get_pipeline_df( self, list_executions_mock, query_execution_inputs_and_outputs_mock From be8ad490311ff4679457d04506f10089bc8e9a9f Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Mon, 10 May 2021 11:55:58 -0400 Subject: [PATCH 22/36] chore: remove deleted sample test --- ..._custom_training_managed_dataset_sample.py | 73 ------------------- ...ne_custom_training_managed_dataset_test.py | 70 ------------------ 2 files changed, 143 deletions(-) delete mode 100644 samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py delete mode 100644 samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py diff --git a/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py b/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py deleted file mode 100644 index 7d7dc6357c..0000000000 --- a/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_sample.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Union - -from google.cloud import aiplatform - - -# [START aiplatform_sdk_create_training_pipeline_custom_job_sample] -def create_training_pipeline_custom_training_managed_dataset_sample( - project: str, - location: str, - display_name: str, - script_path: str, - container_uri: str, - model_serving_container_image_uri: str, - dataset_id: int, - model_display_name: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - replica_count: int = 0, - machine_type: str = "n1-standard-4", - accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", - accelerator_count: int = 0, - training_fraction_split: float = 0.8, - validation_fraction_split: float = 0.1, - test_fraction_split: float = 0.1, - sync: bool = True, -): - aiplatform.init(project=project, location=location) - - job = aiplatform.CustomTrainingJob( - display_name=display_name, - script_path=script_path, - container_uri=container_uri, - model_serving_container_image_uri=model_serving_container_image_uri, - ) - - my_image_ds = aiplatform.ImageDataset(dataset_id) - - model = job.run( - dataset=my_image_ds, - model_display_name=model_display_name, - args=args, - replica_count=replica_count, - machine_type=machine_type, - accelerator_type=accelerator_type, - accelerator_count=accelerator_count, - training_fraction_split=training_fraction_split, - validation_fraction_split=validation_fraction_split, - test_fraction_split=test_fraction_split, - sync=sync, - ) - - model.wait() - - print(model.display_name) - print(model.resource_name) - print(model.uri) - return model - - -# [END aiplatform_sdk_create_training_pipeline_custom_job_sample] diff --git a/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py b/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py deleted file mode 100644 index 4197f658b1..0000000000 --- a/samples/model-builder/create_training_pipeline_custom_training_managed_dataset_test.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import create_training_pipeline_custom_training_managed_dataset_sample -import test_constants as constants - - -def test_create_training_pipeline_custom_job_sample( - mock_sdk_init, - mock_image_dataset, - mock_init_custom_training_job, - mock_run_custom_training_job, - mock_get_image_dataset, -): - - create_training_pipeline_custom_training_managed_dataset_sample.create_training_pipeline_custom_training_managed_dataset_sample( - project=constants.PROJECT, - location=constants.LOCATION, - display_name=constants.DISPLAY_NAME, - args=constants.ARGS, - script_path=constants.SCRIPT_PATH, - container_uri=constants.CONTAINER_URI, - model_serving_container_image_uri=constants.CONTAINER_URI, - dataset_id=constants.RESOURCE_ID, - model_display_name=constants.DISPLAY_NAME_2, - replica_count=constants.REPLICA_COUNT, - machine_type=constants.MACHINE_TYPE, - accelerator_type=constants.ACCELERATOR_TYPE, - accelerator_count=constants.ACCELERATOR_COUNT, - training_fraction_split=constants.TRAINING_FRACTION_SPLIT, - validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, - test_fraction_split=constants.TEST_FRACTION_SPLIT, - ) - - mock_get_image_dataset.assert_called_once_with(constants.RESOURCE_ID) - - mock_sdk_init.assert_called_once_with( - project=constants.PROJECT, location=constants.LOCATION - ) - mock_init_custom_training_job.assert_called_once_with( - display_name=constants.DISPLAY_NAME, - script_path=constants.SCRIPT_PATH, - container_uri=constants.CONTAINER_URI, - model_serving_container_image_uri=constants.CONTAINER_URI, - ) - mock_run_custom_training_job.assert_called_once_with( - dataset=mock_image_dataset, - model_display_name=constants.DISPLAY_NAME_2, - args=constants.ARGS, - replica_count=constants.REPLICA_COUNT, - machine_type=constants.MACHINE_TYPE, - accelerator_type=constants.ACCELERATOR_TYPE, - accelerator_count=constants.ACCELERATOR_COUNT, - training_fraction_split=constants.TRAINING_FRACTION_SPLIT, - validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, - test_fraction_split=constants.TEST_FRACTION_SPLIT, - sync=True, - ) From 6242411d122d8aadad70a4f684ec0fbcad15fbd1 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Mon, 10 May 2021 12:58:04 -0400 Subject: [PATCH 23/36] chore: expose TimeSeriesDataset --- google/cloud/aiplatform/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 3e68e1d0d8..e56e57a2ad 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -23,6 +23,7 @@ ImageDataset, TabularDataset, TextDataset, + TimeSeriesDataset, VideoDataset, ) from google.cloud.aiplatform.models import Endpoint From 7215940b6e9795e2e1461df80b6640bb05bc5a80 Mon Sep 17 00:00:00 2001 From: Yicheng Fang <58752348+yfang1@users.noreply.github.com> Date: Mon, 10 May 2021 11:24:45 -0700 Subject: [PATCH 24/36] feat: adding TB.gcp uploader to python-aiplatform (#368) --- google/cloud/aiplatform/compat/__init__.py | 7 + .../aiplatform/compat/services/__init__.py | 4 + .../cloud/aiplatform/compat/types/__init__.py | 12 + .../cloud/aiplatform/tensorboard/__init__.py | 16 + .../cloud/aiplatform/tensorboard/uploader.py | 1442 ++++++++++++++++ .../aiplatform/tensorboard/uploader_main.py | 148 ++ google/cloud/aiplatform/utils.py | 10 + setup.py | 13 +- tests/unit/aiplatform/test_uploader.py | 1454 +++++++++++++++++ 9 files changed, 3104 insertions(+), 2 deletions(-) create mode 100644 google/cloud/aiplatform/tensorboard/__init__.py create mode 100644 google/cloud/aiplatform/tensorboard/uploader.py create mode 100644 google/cloud/aiplatform/tensorboard/uploader_main.py create mode 100644 tests/unit/aiplatform/test_uploader.py diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index 16cc83a9cd..980c554fe1 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -35,6 +35,7 @@ services.specialist_pool_service_client_v1beta1 ) services.metadata_service_client = services.metadata_service_client_v1beta1 + services.tensorboard_service_client = services.tensorboard_service_client_v1beta1 types.accelerator_type = types.accelerator_type_v1beta1 types.annotation = types.annotation_v1beta1 @@ -71,6 +72,12 @@ types.specialist_pool_service = types.specialist_pool_service_v1beta1 types.training_pipeline = types.training_pipeline_v1beta1 types.metadata_service = types.metadata_service_v1beta1 + types.tensorboard_service = types.tensorboard_service_v1beta1 + types.tensorboard_data = types.tensorboard_data_v1beta1 + types.tensorboard_experiment = types.tensorboard_experiment_v1beta1 + types.tensorboard_run = types.tensorboard_run_v1beta1 + types.tensorboard_service = types.tensorboard_service_v1beta1 + types.tensorboard_time_series = types.tensorboard_time_series_v1beta1 if DEFAULT_VERSION == V1: diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py index 8cbe922cbf..5c104ab41f 100644 --- a/google/cloud/aiplatform/compat/services/__init__.py +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -39,6 +39,9 @@ from google.cloud.aiplatform_v1beta1.services.metadata_service import ( client as metadata_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import ( + client as tensorboard_service_client_v1beta1, +) from google.cloud.aiplatform_v1.services.dataset_service import ( client as dataset_service_client_v1, @@ -80,4 +83,5 @@ prediction_service_client_v1beta1, specialist_pool_service_client_v1beta1, metadata_service_client_v1beta1, + tensorboard_service_client_v1beta1, ) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index 047f1dee1d..f45bb2e11e 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -51,6 +51,12 @@ specialist_pool_service as specialist_pool_service_v1beta1, training_pipeline as training_pipeline_v1beta1, metadata_service as metadata_service_v1beta1, + tensorboard_service as tensorboard_service_v1beta1, + tensorboard_data as tensorboard_data_v1beta1, + tensorboard_experiment as tensorboard_experiment_v1beta1, + tensorboard_run as tensorboard_run_v1beta1, + tensorboard_service as tensorboard_service_v1beta1, + tensorboard_time_series as tensorboard_time_series_v1beta1, ) from google.cloud.aiplatform_v1.types import ( accelerator_type as accelerator_type_v1, @@ -157,4 +163,10 @@ specialist_pool_service_v1beta1, training_pipeline_v1beta1, metadata_service_v1beta1, + tensorboard_service_v1beta1, + tensorboard_data_v1beta1, + tensorboard_experiment_v1beta1, + tensorboard_run_v1beta1, + tensorboard_service_v1beta1, + tensorboard_time_series_v1beta1, ) diff --git a/google/cloud/aiplatform/tensorboard/__init__.py b/google/cloud/aiplatform/tensorboard/__init__.py new file mode 100644 index 0000000000..a6fbe4122f --- /dev/null +++ b/google/cloud/aiplatform/tensorboard/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py new file mode 100644 index 0000000000..57dcbedf60 --- /dev/null +++ b/google/cloud/aiplatform/tensorboard/uploader.py @@ -0,0 +1,1442 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Uploads a TensorBoard logdir to TensorBoard.gcp.""" +import contextlib +import functools +import json +import os +import time +import re +from typing import Callable, Dict, FrozenSet, Generator, Iterable, Optional, Tuple +import uuid + +import grpc +from tensorboard.backend import process_graph +from tensorboard.backend.event_processing.plugin_event_accumulator import ( + directory_loader, +) +from tensorboard.backend.event_processing.plugin_event_accumulator import ( + event_file_loader, +) +from tensorboard.backend.event_processing.plugin_event_accumulator import io_wrapper +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import types_pb2 +from tensorboard.plugins.graph import metadata as graph_metadata +from tensorboard.uploader import logdir_loader +from tensorboard.uploader import upload_tracker +from tensorboard.uploader import util +from tensorboard.uploader.proto import server_info_pb2 +from tensorboard.util import tb_logging +from tensorboard.util import tensor_util +import tensorflow as tf + +from google.api_core import exceptions +from google.cloud import storage +from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1 +from google.cloud.aiplatform.compat.types import ( + tensorboard_data_v1beta1 as tensorboard_data, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_experiment_v1beta1 as tensorboard_experiment, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_run_v1beta1 as tensorboard_run, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_service_v1beta1 as tensorboard_service, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_time_series_v1beta1 as tensorboard_time_series, +) +from google.protobuf import message +from google.protobuf import timestamp_pb2 as timestamp + +TensorboardServiceClient = tensorboard_service_client_v1beta1.TensorboardServiceClient + +# Minimum length of a logdir polling cycle in seconds. Shorter cycles will +# sleep to avoid spinning over the logdir, which isn't great for disks and can +# be expensive for network file systems. +_MIN_LOGDIR_POLL_INTERVAL_SECS = 1 + +# Maximum length of a base-128 varint as used to encode a 64-bit value +# (without the "msb of last byte is bit 63" optimization, to be +# compatible with protobuf and golang varints). +_MAX_VARINT64_LENGTH_BYTES = 10 + +# Default minimum interval between initiating WriteTensorbordRunData RPCs in +# milliseconds. +_DEFAULT_MIN_SCALAR_REQUEST_INTERVAL = 10 + +# Default maximum WriteTensorbordRunData request size in bytes. +_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 24 * (2 ** 10) # 24KiB + +# Default minimum interval between initiating WriteTensorbordRunData RPCs in +# milliseconds. +_DEFAULT_MIN_TENSOR_REQUEST_INTERVAL = 10 + +# Default minimum interval between initiating WriteTensorbordRunData RPCs in +# milliseconds. +_DEFAULT_MIN_BLOB_REQUEST_INTERVAL = 10 + +# Default maximum WriteTensorbordRunData request size in bytes. +_DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2 ** 10) # 512KiB + +_DEFAULT_MAX_BLOB_REQUEST_SIZE = 4 * (2 ** 20) - 256 * (2 ** 10) # 4MiB-256KiB + +# Default maximum tensor point size in bytes. +_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2 ** 10) # 16KiB + +_DEFAULT_MAX_BLOB_SIZE = 10 * (2 ** 30) # 10GiB + +logger = tb_logging.get_logger() + + +class TensorBoardUploader(object): + """Uploads a TensorBoard logdir to TensorBoard.gcp.""" + + def __init__( + self, + experiment_name: str, + tensorboard_resource_name: str, + blob_storage_bucket: storage.Bucket, + blob_storage_folder: str, + writer_client: TensorboardServiceClient, + logdir: str, + allowed_plugins: FrozenSet[str], + experiment_display_name: Optional[str] = None, + upload_limits: Optional[server_info_pb2.UploadLimits] = None, + logdir_poll_rate_limiter: Optional[util.RateLimiter] = None, + rpc_rate_limiter: Optional[util.RateLimiter] = None, + tensor_rpc_rate_limiter: Optional[util.RateLimiter] = None, + blob_rpc_rate_limiter: Optional[util.RateLimiter] = None, + description: Optional[str] = None, + verbosity: int = 1, + one_shot: bool = False, + event_file_inactive_secs: Optional[int] = None, + run_name_prefix=None, + ): + """Constructs a TensorBoardUploader. + + Args: + experiment_name: Name of this experiment. Unique to the given + tensorboard_resource_name. + tensorboard_resource_name: Name of the Tensorboard resource with this + format + projects/{project}/locations/{location}/tensorboards/{tensorboard} + writer_client: a TensorBoardWriterService stub instance + logdir: path of the log directory to upload + experiment_display_name: The display name of the experiment. + allowed_plugins: collection of string plugin names; events will only be + uploaded if their time series's metadata specifies one of these plugin + names + upload_limits: instance of tensorboard.service.UploadLimits proto. + logdir_poll_rate_limiter: a `RateLimiter` to use to limit logdir polling + frequency, to avoid thrashing disks, especially on networked file + systems + rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency. + Note this limit applies at the level of single RPCs in the Scalar and + Tensor case, but at the level of an entire blob upload in the Blob + case-- which may require a few preparatory RPCs and a stream of chunks. + Note the chunk stream is internally rate-limited by backpressure from + the server, so it is not a concern that we do not explicitly rate-limit + within the stream here. + description: String description to assign to the experiment. + verbosity: Level of verbosity, an integer. Supported value: 0 - No upload + statistics is printed. 1 - Print upload statistics while uploading data + (default). + one_shot: Once uploading starts, upload only the existing data in the + logdir and then return immediately, instead of the default behavior of + continuing to listen for new data in the logdir and upload them when it + appears. + event_file_inactive_secs: Age in seconds of last write after which an + event file is considered inactive. If none then event file is never + considered inactive. + run_name_prefix: If present, all runs created by this invocation will have + their name prefixed by this value. + """ + self._experiment_name = experiment_name + self._experiment_display_name = experiment_display_name + self._tensorboard_resource_name = tensorboard_resource_name + self._blob_storage_bucket = blob_storage_bucket + self._blob_storage_folder = blob_storage_folder + self._api = writer_client + self._logdir = logdir + self._allowed_plugins = frozenset(allowed_plugins) + self._run_name_prefix = run_name_prefix + + self._upload_limits = upload_limits + if not self._upload_limits: + self._upload_limits = server_info_pb2.UploadLimits() + self._upload_limits.max_scalar_request_size = ( + _DEFAULT_MAX_SCALAR_REQUEST_SIZE + ) + self._upload_limits.min_scalar_request_interval = ( + _DEFAULT_MIN_SCALAR_REQUEST_INTERVAL + ) + self._upload_limits.min_tensor_request_interval = ( + _DEFAULT_MIN_TENSOR_REQUEST_INTERVAL + ) + self._upload_limits.max_tensor_request_size = ( + _DEFAULT_MAX_TENSOR_REQUEST_SIZE + ) + self._upload_limits.max_tensor_point_size = _DEFAULT_MAX_TENSOR_POINT_SIZE + self._upload_limits.min_blob_request_interval = ( + _DEFAULT_MIN_BLOB_REQUEST_INTERVAL + ) + self._upload_limits.max_blob_request_size = _DEFAULT_MAX_BLOB_REQUEST_SIZE + self._upload_limits.max_blob_size = _DEFAULT_MAX_BLOB_SIZE + + self._description = description + self._verbosity = verbosity + self._one_shot = one_shot + self._request_sender = None + if logdir_poll_rate_limiter is None: + self._logdir_poll_rate_limiter = util.RateLimiter( + _MIN_LOGDIR_POLL_INTERVAL_SECS + ) + else: + self._logdir_poll_rate_limiter = logdir_poll_rate_limiter + + if rpc_rate_limiter is None: + self._rpc_rate_limiter = util.RateLimiter( + self._upload_limits.min_scalar_request_interval / 1000 + ) + else: + self._rpc_rate_limiter = rpc_rate_limiter + + if tensor_rpc_rate_limiter is None: + self._tensor_rpc_rate_limiter = util.RateLimiter( + self._upload_limits.min_tensor_request_interval / 1000 + ) + else: + self._tensor_rpc_rate_limiter = tensor_rpc_rate_limiter + + if blob_rpc_rate_limiter is None: + self._blob_rpc_rate_limiter = util.RateLimiter( + self._upload_limits.min_blob_request_interval / 1000 + ) + else: + self._blob_rpc_rate_limiter = blob_rpc_rate_limiter + + def active_filter(secs): + return ( + not bool(event_file_inactive_secs) + or secs + event_file_inactive_secs >= time.time() + ) + + directory_loader_factory = functools.partial( + directory_loader.DirectoryLoader, + loader_factory=event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile, + active_filter=active_filter, + ) + self._logdir_loader = logdir_loader.LogdirLoader( + self._logdir, directory_loader_factory + ) + self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity) + + def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperiment: + """Create an experiment or get an experiment. + + Attempts to create an experiment. If the experiment already exists and + creation fails then the experiment will be retrieved. + + Returns: + The created or retrieved experiment. + """ + logger.info("Creating experiment") + + tb_experiment = tensorboard_experiment.TensorboardExperiment( + description=self._description, display_name=self._experiment_display_name + ) + + try: + experiment = self._api.create_tensorboard_experiment( + parent=self._tensorboard_resource_name, + tensorboard_experiment=tb_experiment, + tensorboard_experiment_id=self._experiment_name, + ) + except exceptions.AlreadyExists: + logger.info("Creating experiment failed. Retrieving experiment.") + experiment_name = os.path.join( + self._tensorboard_resource_name, "experiments", self._experiment_name + ) + experiment = self._api.get_tensorboard_experiment(name=experiment_name) + return experiment + + def create_experiment(self): + """Creates an Experiment for this upload session and returns the ID.""" + + experiment = self._create_or_get_experiment() + self._experiment = experiment + self._request_sender = _BatchedRequestSender( + self._experiment.name, + self._api, + allowed_plugins=self._allowed_plugins, + upload_limits=self._upload_limits, + rpc_rate_limiter=self._rpc_rate_limiter, + tensor_rpc_rate_limiter=self._tensor_rpc_rate_limiter, + blob_rpc_rate_limiter=self._blob_rpc_rate_limiter, + blob_storage_bucket=self._blob_storage_bucket, + blob_storage_folder=self._blob_storage_folder, + tracker=self._tracker, + ) + + def get_experiment_resource_name(self): + return self._experiment.name + + def start_uploading(self): + """Blocks forever to continuously upload data from the logdir. + + Raises: + RuntimeError: If `create_experiment` has not yet been called. + ExperimentNotFoundError: If the experiment is deleted during the + course of the upload. + """ + if self._request_sender is None: + raise RuntimeError("Must call create_experiment() before start_uploading()") + while True: + self._logdir_poll_rate_limiter.tick() + self._upload_once() + if self._one_shot: + break + if self._one_shot and not self._tracker.has_data(): + logger.warning( + "One-shot mode was used on a logdir (%s) " + "without any uploadable data" % self._logdir + ) + + def _upload_once(self): + """Runs one upload cycle, sending zero or more RPCs.""" + logger.info("Starting an upload cycle") + + sync_start_time = time.time() + self._logdir_loader.synchronize_runs() + sync_duration_secs = time.time() - sync_start_time + logger.info("Logdir sync took %.3f seconds", sync_duration_secs) + + run_to_events = self._logdir_loader.get_run_events() + if self._run_name_prefix: + run_to_events = { + self._run_name_prefix + k: v for k, v in run_to_events.items() + } + with self._tracker.send_tracker(): + self._request_sender.send_requests(run_to_events) + + +class ExperimentNotFoundError(RuntimeError): + pass + + +class PermissionDeniedError(RuntimeError): + pass + + +class ExistingResourceNotFoundError(RuntimeError): + """Resource could not be created or retrieved.""" + + +class _OutOfSpaceError(Exception): + """Action could not proceed without overflowing request budget. + + This is a signaling exception (like `StopIteration`) used internally + by `_*RequestSender`; it does not mean that anything has gone wrong. + """ + + pass + + +class _BatchedRequestSender(object): + """Helper class for building requests that fit under a size limit. + + This class maintains stateful request builders for each of the possible + request types (scalars, tensors, and blobs). These accumulate batches + independently, each maintaining its own byte budget and emitting a request + when the batch becomes full. As a consequence, events of different types + will likely be sent to the backend out of order. E.g., in the extreme case, + a single tensor-flavored request may be sent only when the event stream is + exhausted, even though many more recent scalar events were sent earlier. + + This class is not threadsafe. Use external synchronization if + calling its methods concurrently. + """ + + def __init__( + self, + experiment_resource_name: str, + api: TensorboardServiceClient, + allowed_plugins: Iterable[str], + upload_limits: server_info_pb2.UploadLimits, + rpc_rate_limiter: util.RateLimiter, + tensor_rpc_rate_limiter: util.RateLimiter, + blob_rpc_rate_limiter: util.RateLimiter, + blob_storage_bucket: storage.Bucket, + blob_storage_folder: str, + tracker: upload_tracker.UploadTracker, + ): + """Constructs _BatchedRequestSender for the given experiment resource. + + Args: + experiment_resource_name: Name of the experiment resource of the form + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment} + api: Tensorboard service stub used to interact with experiment resource. + allowed_plugins: The plugins supported by the Tensorboard.gcp resource. + upload_limits: Upload limits for for api calls. + rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency. + Note this limit applies at the level of single RPCs in the Scalar and + Tensor case, but at the level of an entire blob upload in the Blob + case-- which may require a few preparatory RPCs and a stream of chunks. + Note the chunk stream is internally rate-limited by backpressure from + the server, so it is not a concern that we do not explicitly rate-limit + within the stream here. + tracker: Upload tracker to track information about uploads. + """ + self._experiment_resource_name = experiment_resource_name + self._api = api + self._tag_metadata = {} + self._allowed_plugins = frozenset(allowed_plugins) + self._tracker = tracker + self._run_to_request_sender: Dict[str, _ScalarBatchedRequestSender] = {} + self._run_to_tensor_request_sender: Dict[str, _TensorBatchedRequestSender] = {} + self._run_to_blob_request_sender: Dict[str, _BlobRequestSender] = {} + self._run_to_run_resource: Dict[str, tensorboard_run.TensorboardRun] = {} + self._scalar_request_sender_factory = functools.partial( + _ScalarBatchedRequestSender, + api=api, + rpc_rate_limiter=rpc_rate_limiter, + max_request_size=upload_limits.max_scalar_request_size, + tracker=self._tracker, + ) + self._tensor_request_sender_factory = functools.partial( + _TensorBatchedRequestSender, + api=api, + rpc_rate_limiter=tensor_rpc_rate_limiter, + max_request_size=upload_limits.max_tensor_request_size, + max_tensor_point_size=upload_limits.max_tensor_point_size, + tracker=self._tracker, + ) + self._blob_request_sender_factory = functools.partial( + _BlobRequestSender, + api=api, + rpc_rate_limiter=blob_rpc_rate_limiter, + max_blob_request_size=upload_limits.max_blob_request_size, + max_blob_size=upload_limits.max_blob_size, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + tracker=self._tracker, + ) + + def send_requests( + self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]] + ): + """Accepts a stream of TF events and sends batched write RPCs. + + Each sent request will be batched, the size of each batch depending on + the type of data (Scalar vs Tensor vs Blob) being sent. + + Args: + run_to_events: Mapping from run name to generator of `tf.compat.v1.Event` + values, as returned by `LogdirLoader.get_run_events`. + + Raises: + RuntimeError: If no progress can be made because even a single + point is too large (say, due to a gigabyte-long tag name). + """ + + for (run_name, event, value) in self._run_values(run_to_events): + time_series_key = (run_name, value.tag) + + # The metadata for a time series is memorized on the first event. + # If later events arrive with a mismatching plugin_name, they are + # ignored with a warning. + metadata = self._tag_metadata.get(time_series_key) + first_in_time_series = False + if metadata is None: + first_in_time_series = True + metadata = value.metadata + self._tag_metadata[time_series_key] = metadata + + plugin_name = metadata.plugin_data.plugin_name + if value.HasField("metadata") and ( + plugin_name != value.metadata.plugin_data.plugin_name + ): + logger.warning( + "Mismatching plugin names for %s. Expected %s, found %s.", + time_series_key, + metadata.plugin_data.plugin_name, + value.metadata.plugin_data.plugin_name, + ) + continue + if plugin_name not in self._allowed_plugins: + if first_in_time_series: + logger.info( + "Skipping time series %r with unsupported plugin name %r", + time_series_key, + plugin_name, + ) + continue + self._tracker.add_plugin_name(plugin_name) + # If this is the first time we've seen this run create a new run resource + # and an associated request sender. + if run_name not in self._run_to_run_resource: + self._create_or_get_run_resource(run_name) + self._run_to_request_sender[ + run_name + ] = self._scalar_request_sender_factory( + self._run_to_run_resource[run_name].name + ) + self._run_to_tensor_request_sender[ + run_name + ] = self._tensor_request_sender_factory( + self._run_to_run_resource[run_name].name + ) + self._run_to_blob_request_sender[ + run_name + ] = self._blob_request_sender_factory( + self._run_to_run_resource[run_name].name + ) + + if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR: + self._run_to_request_sender[run_name].add_event(event, value, metadata) + elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR: + self._run_to_tensor_request_sender[run_name].add_event( + event, value, metadata + ) + elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE: + self._run_to_blob_request_sender[run_name].add_event( + event, value, metadata + ) + + for scalar_request_sender in self._run_to_request_sender.values(): + scalar_request_sender.flush() + + for tensor_request_sender in self._run_to_tensor_request_sender.values(): + tensor_request_sender.flush() + + for blob_request_sender in self._run_to_blob_request_sender.values(): + blob_request_sender.flush() + + def _create_or_get_run_resource(self, run_name: str): + """Creates a new Run Resource in current Tensorboard Experiment resource. + + Args: + run_name: The display name of this run. + """ + tb_run = tensorboard_run.TensorboardRun() + tb_run.display_name = run_name + try: + tb_run = self._api.create_tensorboard_run( + parent=self._experiment_resource_name, + tensorboard_run=tb_run, + tensorboard_run_id=str(uuid.uuid4()), + ) + except exceptions.InvalidArgument as e: + # If the run name already exists then retrieve it + if "already exist" in e.message: + runs_pages = self._api.list_tensorboard_runs( + parent=self._experiment_resource_name + ) + for tb_run in runs_pages: + if tb_run.display_name == run_name: + break + + if tb_run.display_name != run_name: + raise ExistingResourceNotFoundError( + "Run with name %s already exists but is not resource list." + % run_name + ) + else: + raise + + self._run_to_run_resource[run_name] = tb_run + + def _run_values( + self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]] + ) -> Generator[ + Tuple[str, tf.compat.v1.Event, tf.compat.v1.Summary.Value], None, None + ]: + """Helper generator to create a single stream of work items. + + Note that `dataclass_compat` may emit multiple variants of + the same event, for backwards compatibility. Thus this stream should + be filtered to obtain the desired version of each event. Here, we + ignore any event that does not have a `summary` field. + + Furthermore, the events emitted here could contain values that do not + have `metadata.data_class` set; these too should be ignored. In + `_send_summary_value(...)` above, we switch on `metadata.data_class` + and drop any values with an unknown (i.e., absent or unrecognized) + `data_class`. + + Args: + run_to_events: Mapping from run name to generator of `tf.compat.v1.Event` + values, as returned by `LogdirLoader.get_run_events`. + + Yields: + Tuple of run name, tf.compat.v1.Event, tf.compat.v1.Summary.Value per + value. + """ + # Note that this join in principle has deletion anomalies: if the input + # stream contains runs with no events, or events with no values, we'll + # lose that information. This is not a problem: we would need to prune + # such data from the request anyway. + for (run_name, events) in run_to_events.items(): + for event in events: + _filter_graph_defs(event) + for value in event.summary.value: + yield (run_name, event, value) + + +class _TimeSeriesResourceManager(object): + """Helper class managing Time Series resources.""" + + def __init__(self, run_resource_id: str, api: TensorboardServiceClient): + """Constructor for _TimeSeriesResourceManager. + + Args: + run_resource_id: The resource id for the run with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + api: TensorboardServiceStub + """ + self._run_resource_id = run_resource_id + self._api = api + self._tag_to_time_series_proto: Dict[ + str, tensorboard_time_series.TensorboardTimeSeries + ] = {} + + def get_or_create( + self, + tag_name: str, + time_series_resource_creator: Callable[ + [], tensorboard_time_series.TensorboardTimeSeries + ], + ) -> tensorboard_time_series.TensorboardTimeSeries: + """get a time series resource with given tag_name, and create a new one on + + OnePlatform if not present. + + Args: + tag_name: The tag name of the time series in the Tensorboard log dir. + time_series_resource_creator: A callable that produces a TimeSeries for + creation. + """ + if tag_name in self._tag_to_time_series_proto: + return self._tag_to_time_series_proto[tag_name] + + time_series = time_series_resource_creator() + time_series.display_name = tag_name + try: + time_series = self._api.create_tensorboard_time_series( + parent=self._run_resource_id, tensorboard_time_series=time_series + ) + except exceptions.InvalidArgument as e: + # If the time series display name already exists then retrieve it + if "already exist" in e.message: + list_of_time_series = self._api.list_tensorboard_time_series( + request=tensorboard_service.ListTensorboardTimeSeriesRequest( + parent=self._run_resource_id, + filter="display_name = {}".format(json.dumps(str(tag_name))), + ) + ) + num = 0 + for ts in list_of_time_series: + time_series = ts + num += 1 + break + if num != 1: + raise ValueError( + "More than one time series resource found with display_name: {}".format( + tag_name + ) + ) + else: + raise + + self._tag_to_time_series_proto[tag_name] = time_series + return time_series + + +class _ScalarBatchedRequestSender(object): + """Helper class for building requests that fit under a size limit. + + This class accumulates a current request. `add_event(...)` may or may not + send the request (and start a new one). After all `add_event(...)` calls + are complete, a final call to `flush()` is needed to send the final request. + + This class is not threadsafe. Use external synchronization if calling its + methods concurrently. + """ + + def __init__( + self, + run_resource_id: str, + api: TensorboardServiceClient, + rpc_rate_limiter: util.RateLimiter, + max_request_size: int, + tracker: upload_tracker.UploadTracker, + ): + """Constructer for _ScalarBatchedRequestSender. + + Args: + run_resource_id: The resource id for the run with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + api: TensorboardServiceStub + rpc_rate_limiter: until.RateLimiter to limit rate of this request sender + max_request_size: max number of bytes to send + tracker: + """ + self._run_resource_id = run_resource_id + self._api = api + self._rpc_rate_limiter = rpc_rate_limiter + self._byte_budget_manager = _ByteBudgetManager(max_request_size) + self._tracker = tracker + + # cache: map from Tensorboard tag to TimeSeriesData + # cleared whenever a new request is created + self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {} + + self._time_series_resource_manager = _TimeSeriesResourceManager( + self._run_resource_id, self._api + ) + self._new_request() + + def _new_request(self): + """Allocates a new request and refreshes the budget.""" + self._request = tensorboard_service.WriteTensorboardRunDataRequest() + self._tag_to_time_series_data.clear() + self._num_values = 0 + self._request.tensorboard_run = self._run_resource_id + self._byte_budget_manager.reset(self._request) + + def add_event( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + """Attempts to add the given event to the current request. + + If the event cannot be added to the current request because the byte + budget is exhausted, the request is flushed, and the event is added + to the next request. + + Args: + event: The tf.compat.v1.Event event containing the value. + value: A scalar tf.compat.v1.Summary.Value. + metadata: SummaryMetadata of the event. + """ + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + self.flush() + # Try again. This attempt should never produce OutOfSpaceError + # because we just flushed. + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + raise RuntimeError("add_event failed despite flush") + + def _add_event_internal( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + self._num_values += 1 + time_series_data_proto = self._tag_to_time_series_data.get(value.tag) + if time_series_data_proto is None: + time_series_data_proto = self._create_time_series_data(value.tag, metadata) + self._create_point(time_series_data_proto, event, value) + + def flush(self): + """Sends the active request after removing empty runs and tags. + + Starts a new, empty active request. + """ + request = self._request + request.time_series_data = list(self._tag_to_time_series_data.values()) + _prune_empty_time_series(request) + if not request.time_series_data: + return + + self._rpc_rate_limiter.tick() + + with _request_logger(request): + with self._tracker.scalars_tracker(self._num_values): + try: + self._api.write_tensorboard_run_data( + tensorboard_run=self._run_resource_id, + time_series_data=request.time_series_data, + ) + except grpc.RpcError as e: + if ( + hasattr(e, "code") + and getattr(e, "code")() == grpc.StatusCode.NOT_FOUND + ): + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + self._new_request() + + def _create_time_series_data( + self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata + ) -> tensorboard_data.TimeSeriesData: + """Adds a time_series for the tag_name, if there's space. + + Args: + tag_name: String name of the tag to add (as `value.tag`). + + Returns: + The TimeSeriesData in _request proto with the given tag name. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining + request budget. + """ + time_series_data_proto = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id=self._time_series_resource_manager.get_or_create( + tag_name, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=tag_name, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ), + ).name.split("/")[-1], + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + ) + + self._request.time_series_data.extend([time_series_data_proto]) + self._byte_budget_manager.add_time_series(time_series_data_proto) + self._tag_to_time_series_data[tag_name] = time_series_data_proto + return time_series_data_proto + + def _create_point( + self, + time_series_proto: tensorboard_data.TimeSeriesData, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + ): + """Adds a scalar point to the given tag, if there's space. + + Args: + time_series_proto: TimeSeriesData proto to which to add a point. + event: Enclosing `Event` proto with the step and wall time data. + value: Scalar `Summary.Value` proto with the actual scalar data. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining + request budget. + """ + scalar_proto = tensorboard_data.Scalar( + value=tensor_util.make_ndarray(value.tensor).item() + ) + point = tensorboard_data.TimeSeriesDataPoint( + step=event.step, + scalar=scalar_proto, + wall_time=timestamp.Timestamp( + seconds=int(event.wall_time), + nanos=int(round((event.wall_time % 1) * 10 ** 9)), + ), + ) + time_series_proto.values.extend([point]) + try: + self._byte_budget_manager.add_point(point) + except _OutOfSpaceError: + time_series_proto.values.pop() + raise + + +class _TensorBatchedRequestSender(object): + """Helper class for building WriteTensor() requests that fit under a size limit. + + This class accumulates a current request. `add_event(...)` may or may not + send the request (and start a new one). After all `add_event(...)` calls + are complete, a final call to `flush()` is needed to send the final request. + This class is not threadsafe. Use external synchronization if calling its + methods concurrently. + """ + + def __init__( + self, + run_resource_id: str, + api: TensorboardServiceClient, + rpc_rate_limiter: util.RateLimiter, + max_request_size: int, + max_tensor_point_size: int, + tracker: upload_tracker.UploadTracker, + ): + """Constructer for _TensorBatchedRequestSender. + + Args: + run_resource_id: The resource id for the run with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + api: TensorboardServiceStub + rpc_rate_limiter: until.RateLimiter to limit rate of this request sender + max_request_size: max number of bytes to send + tracker: + """ + self._run_resource_id = run_resource_id + self._api = api + self._rpc_rate_limiter = rpc_rate_limiter + self._byte_budget_manager = _ByteBudgetManager(max_request_size) + self._max_tensor_point_size = max_tensor_point_size + self._tracker = tracker + + # cache: map from Tensorboard tag to TimeSeriesData + # cleared whenever a new request is created + self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {} + + self._time_series_resource_manager = _TimeSeriesResourceManager( + run_resource_id, api + ) + self._new_request() + + def _new_request(self): + """Allocates a new request and refreshes the budget.""" + self._request = tensorboard_service.WriteTensorboardRunDataRequest() + self._tag_to_time_series_data.clear() + self._num_values = 0 + self._request.tensorboard_run = self._run_resource_id + self._byte_budget_manager.reset(self._request) + self._num_values = 0 + self._num_values_skipped = 0 + self._tensor_bytes = 0 + self._tensor_bytes_skipped = 0 + + def add_event( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + """Attempts to add the given event to the current request. + + If the event cannot be added to the current request because the byte + budget is exhausted, the request is flushed, and the event is added + to the next request. + """ + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + self.flush() + # Try again. This attempt should never produce OutOfSpaceError + # because we just flushed. + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + raise RuntimeError("add_event failed despite flush") + + def _add_event_internal( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + self._num_values += 1 + time_series_data_proto = self._tag_to_time_series_data.get(value.tag) + if time_series_data_proto is None: + time_series_data_proto = self._create_time_series_data(value.tag, metadata) + self._create_point(time_series_data_proto, event, value) + + def flush(self): + """Sends the active request after removing empty runs and tags. + + Starts a new, empty active request. + """ + request = self._request + request.time_series_data = list(self._tag_to_time_series_data.values()) + _prune_empty_time_series(request) + if not request.time_series_data: + return + + self._rpc_rate_limiter.tick() + + with _request_logger(request): + with self._tracker.tensors_tracker( + self._num_values, + self._num_values_skipped, + self._tensor_bytes, + self._tensor_bytes_skipped, + ): + try: + self._api.write_tensorboard_run_data( + tensorboard_run=self._run_resource_id, + time_series_data=request.time_series_data, + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + self._new_request() + + def _create_time_series_data( + self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata + ) -> tensorboard_data.TimeSeriesData: + """Adds a time_series for the tag_name, if there's space. + + Args: + tag_name: String name of the tag to add (as `value.tag`). + metadata: SummaryMetadata of the event. + + Returns: + The TimeSeriesData in _request proto with the given tag name. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining + request budget. + """ + time_series_data_proto = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id=self._time_series_resource_manager.get_or_create( + tag_name, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=tag_name, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ), + ).name.split("/")[-1], + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR, + ) + + self._request.time_series_data.extend([time_series_data_proto]) + self._byte_budget_manager.add_time_series(time_series_data_proto) + self._tag_to_time_series_data[tag_name] = time_series_data_proto + return time_series_data_proto + + def _create_point( + self, + time_series_proto: tensorboard_data.TimeSeriesData, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + ): + """Adds a tensor point to the given tag, if there's space. + + Args: + tag_proto: `WriteTensorRequest.Tag` proto to which to add a point. + event: Enclosing `Event` proto with the step and wall time data. + value: Tensor `Summary.Value` proto with the actual tensor data. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining + request budget. + """ + point = tensorboard_data.TimeSeriesDataPoint( + step=event.step, + tensor=tensorboard_data.TensorboardTensor( + value=value.tensor.SerializeToString() + ), + wall_time=timestamp.Timestamp( + seconds=int(event.wall_time), + nanos=int(round((event.wall_time % 1) * 10 ** 9)), + ), + ) + + self._num_values += 1 + tensor_size = len(point.tensor.value) + self._tensor_bytes += tensor_size + if tensor_size > self._max_tensor_point_size: + logger.warning( + "Tensor too large; skipping. " "Size %d exceeds limit of %d bytes.", + tensor_size, + self._max_tensor_point_size, + ) + self._num_values_skipped += 1 + self._tensor_bytes_skipped += tensor_size + return + + self._validate_tensor_value( + value.tensor, value.tag, event.step, event.wall_time + ) + + time_series_proto.values.extend([point]) + + try: + self._byte_budget_manager.add_point(point) + except _OutOfSpaceError: + time_series_proto.values.pop() + raise + + def _validate_tensor_value(self, tensor_proto, tag, step, wall_time): + """Validate a TensorProto by attempting to parse it.""" + try: + tensor_util.make_ndarray(tensor_proto) + except ValueError as error: + raise ValueError( + "The uploader failed to upload a tensor. This seems to be " + "due to a malformation in the tensor, which may be caused by " + "a bug in the process that wrote the tensor.\n\n" + "The tensor has tag '%s' and is at step %d and wall_time %.6f.\n\n" + "Original error:\n%s" % (tag, step, wall_time, error) + ) + + +class _ByteBudgetManager(object): + """Helper class for managing the request byte budget for certain RPCs. + + This should be used for RPCs that organize data by Runs, Tags, and Points, + specifically WriteScalar and WriteTensor. + + Any call to add_time_series() or add_point() may raise an + _OutOfSpaceError, which is non-fatal. It signals to the caller that they + should flush the current request and begin a new one. + + For more information on the protocol buffer encoding and how byte cost + can be calculated, visit: + + https://developers.google.com/protocol-buffers/docs/encoding + """ + + def __init__(self, max_bytes: int): + # The remaining number of bytes that we may yet add to the request. + self._byte_budget = None # type: int + self._max_bytes = max_bytes + + def reset(self, base_request: tensorboard_service.WriteTensorboardRunDataRequest): + """Resets the byte budget and calculates the cost of the base request. + + Args: + base_request: Base request. + + Raises: + _OutOfSpaceError: If the size of the request exceeds the entire + request byte budget. + """ + self._byte_budget = self._max_bytes + self._byte_budget -= ( + base_request._pb.ByteSize() + ) # pylint: disable=protected-access + if self._byte_budget < 0: + raise _OutOfSpaceError("Byte budget too small for base request") + + def add_time_series(self, time_series_proto: tensorboard_data.TimeSeriesData): + """Integrates the cost of a tag proto into the byte budget. + + Args: + time_series_proto: The proto representing a time series. + + Raises: + _OutOfSpaceError: If adding the time_series would exceed the remaining + request budget. + """ + cost = ( + # The size of the tag proto without any tag fields set. + time_series_proto._pb.ByteSize() # pylint: disable=protected-access + # The size of the varint that describes the length of the tag + # proto. We can't yet know the final size of the tag proto -- we + # haven't yet set any point values -- so we can't know the final + # size of this length varint. We conservatively assume it is maximum + # size. + + _MAX_VARINT64_LENGTH_BYTES + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + + def add_point(self, point_proto: tensorboard_data.TimeSeriesDataPoint): + """Integrates the cost of a point proto into the byte budget. + + Args: + point_proto: The proto representing a point. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining request + budget. + """ + submessage_cost = point_proto._pb.ByteSize() # pylint: disable=protected-access + cost = ( + # The size of the point proto. + submessage_cost + # The size of the varint that describes the length of the point + # proto. + + _varint_cost(submessage_cost) + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + + +class _BlobRequestSender(object): + """Uploader for blob-type event data. + + Unlike the other types, this class does not accumulate events in batches; + every blob is sent individually and immediately. Nonetheless we retain + the `add_event()`/`flush()` structure for symmetry. + + This class is not threadsafe. Use external synchronization if calling its + methods concurrently. + """ + + def __init__( + self, + run_resource_id: str, + api: TensorboardServiceClient, + rpc_rate_limiter: util.RateLimiter, + max_blob_request_size: int, + max_blob_size: int, + blob_storage_bucket: storage.Bucket, + blob_storage_folder: str, + tracker: upload_tracker.UploadTracker, + ): + self._run_resource_id = run_resource_id + self._api = api + self._rpc_rate_limiter = rpc_rate_limiter + self._max_blob_request_size = max_blob_request_size + self._max_blob_size = max_blob_size + self._tracker = tracker + self._time_series_resource_manager = _TimeSeriesResourceManager( + run_resource_id, api + ) + + self._bucket = blob_storage_bucket + self._folder = blob_storage_folder + + self._new_request() + + def _new_request(self): + """Declares the previous event complete.""" + self._event = None + self._value = None + self._metadata = None + + def add_event( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + """Attempts to add the given event to the current request. + + If the event cannot be added to the current request because the byte + budget is exhausted, the request is flushed, and the event is added + to the next request. + """ + if self._value: + raise RuntimeError("Tried to send blob while another is pending") + self._event = event # provides step and possibly plugin_name + self._value = value + self._blobs = tensor_util.make_ndarray(self._value.tensor) + if self._blobs.ndim == 1: + self._metadata = metadata + self.flush() + else: + logger.warning( + "A blob sequence must be represented as a rank-1 Tensor. " + "Provided data has rank %d, for run %s, tag %s, step %s ('%s' plugin) .", + self._blobs.ndim, + self._run_resource_id, + self._value.tag, + self._event.step, + metadata.plugin_data.plugin_name, + ) + # Skip this upload. + self._new_request() + + def flush(self): + """Sends the current blob sequence fully, and clears it to make way for the next.""" + if not self._value: + self._new_request() + return + + time_series_proto = self._time_series_resource_manager.get_or_create( + self._value.tag, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=self._value.tag, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE, + plugin_name=self._metadata.plugin_data.plugin_name, + plugin_data=self._metadata.plugin_data.content, + ), + ) + m = re.match( + ".*/tensorboards/(.*)/experiments/(.*)/runs/(.*)/timeSeries/(.*)", + time_series_proto.name, + ) + blob_path_prefix = "tensorboard-{}/{}/{}/{}".format(m[1], m[2], m[3], m[4]) + blob_path_prefix = ( + "{}/{}".format(self._folder, blob_path_prefix) + if self._folder + else blob_path_prefix + ) + sent_blob_ids = [] + for blob in self._blobs: + self._rpc_rate_limiter.tick() + with self._tracker.blob_tracker(len(blob)) as blob_tracker: + blob_id = self._send_blob(blob, blob_path_prefix) + if blob_id is not None: + sent_blob_ids.append(str(blob_id)) + blob_tracker.mark_uploaded(blob_id is not None) + + data_point = tensorboard_data.TimeSeriesDataPoint( + step=self._event.step, + blobs=tensorboard_data.TensorboardBlobSequence( + values=[ + tensorboard_data.TensorboardBlob(id=blob_id) + for blob_id in sent_blob_ids + ] + ), + wall_time=timestamp.Timestamp( + seconds=int(self._event.wall_time), + nanos=int(round((self._event.wall_time % 1) * 10 ** 9)), + ), + ) + + time_series_data_proto = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id=time_series_proto.name.split("/")[-1], + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE, + values=[data_point], + ) + request = tensorboard_service.WriteTensorboardRunDataRequest( + time_series_data=[time_series_data_proto] + ) + + _prune_empty_time_series(request) + if not request.time_series_data: + return + + with _request_logger(request): + try: + self._api.write_tensorboard_run_data( + tensorboard_run=self._run_resource_id, + time_series_data=request.time_series_data, + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + self._new_request() + + def _send_blob(self, blob, blob_path_prefix): + """Sends a single blob to a GCS bucket in the consumer project. + + The blob will not be sent if it is too large. + + Returns: + The ID of blob successfully sent. + """ + if len(blob) > self._max_blob_size: + logger.warning( + "Blob too large; skipping. Size %d exceeds limit of %d bytes.", + len(blob), + self._max_blob_size, + ) + return None + + blob_id = uuid.uuid4() + blob_path = ( + "{}/{}".format(blob_path_prefix, blob_id) if blob_path_prefix else blob_id + ) + self._bucket.blob(blob_path).upload_from_string(blob) + return blob_id + + +@contextlib.contextmanager +def _request_logger(request: tensorboard_service.WriteTensorboardRunDataRequest): + """Context manager to log request size and duration.""" + upload_start_time = time.time() + request_bytes = request._pb.ByteSize() # pylint: disable=protected-access + logger.info("Trying request of %d bytes", request_bytes) + yield + upload_duration_secs = time.time() - upload_start_time + logger.info( + "Upload of (%d bytes) took %.3f seconds", request_bytes, upload_duration_secs, + ) + + +def _varint_cost(n: int): + """Computes the size of `n` encoded as an unsigned base-128 varint. + + This should be consistent with the proto wire format: + + + Args: + n: A non-negative integer. + + Returns: + An integer number of bytes. + """ + result = 1 + while n >= 128: + result += 1 + n >>= 7 + return result + + +def _prune_empty_time_series( + request: tensorboard_service.WriteTensorboardRunDataRequest, +): + """Removes empty time_series from request.""" + for (time_series_idx, time_series_data) in reversed( + list(enumerate(request.time_series_data)) + ): + if not time_series_data.values: + del request.time_series_data[time_series_idx] + + +def _filter_graph_defs(event: tf.compat.v1.Event): + """Filters graph definitions. + + Args: + event: tf.compat.v1.Event to filter. + """ + for v in event.summary.value: + if v.metadata.plugin_data.plugin_name != graph_metadata.PLUGIN_NAME: + continue + if v.tag == graph_metadata.RUN_GRAPH_NAME: + data = list(v.tensor.string_val) + filtered_data = [_filtered_graph_bytes(x) for x in data] + filtered_data = [x for x in filtered_data if x is not None] + if filtered_data != data: + new_tensor = tensor_util.make_tensor_proto( + filtered_data, dtype=types_pb2.DT_STRING + ) + v.tensor.CopyFrom(new_tensor) + + +def _filtered_graph_bytes(graph_bytes: bytes): + """Prepares the graph to be served to the front-end. + + For now, it supports filtering out attributes that are too large to be shown + in the graph UI. + + Args: + graph_bytes: Graph definition. + + Returns: + Filtered graph. + """ + try: + graph_def = graph_pb2.GraphDef().FromString(graph_bytes) + # The reason for the RuntimeWarning catch here is b/27494216, whereby + # some proto parsers incorrectly raise that instead of DecodeError + # on certain kinds of malformed input. Triggering this seems to require + # a combination of mysterious circumstances. + except (message.DecodeError, RuntimeWarning): + logger.warning( + "Could not parse GraphDef of size %d. Skipping.", len(graph_bytes), + ) + return None + # Use the default filter parameters: + # limit_attr_size=1024, large_attrs_key="_too_large_attrs" + process_graph.prepare_graph_for_ui(graph_def) + return graph_def.SerializeToString() diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py new file mode 100644 index 0000000000..60298b5e5c --- /dev/null +++ b/google/cloud/aiplatform/tensorboard/uploader_main.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Launches Tensorboard Uploader for TB.GCP.""" +import re + +from absl import app +from absl import flags +import grpc +from tensorboard.plugins.scalar import metadata as scalar_metadata +from tensorboard.plugins.distribution import metadata as distribution_metadata +from tensorboard.plugins.histogram import metadata as histogram_metadata +from tensorboard.plugins.text import metadata as text_metadata +from tensorboard.plugins.hparams import metadata as hparams_metadata +from tensorboard.plugins.image import metadata as images_metadata +from tensorboard.plugins.graph import metadata as graphs_metadata + +from google.cloud import storage +from google.cloud import aiplatform +from google.cloud.aiplatform.tensorboard import uploader +from google.cloud.aiplatform.utils import TensorboardClientWithOverride + +FLAGS = flags.FLAGS +flags.DEFINE_string("experiment_name", None, "The name of the Cloud AI Experiment.") +flags.DEFINE_string( + "experiment_display_name", None, "The display name of the Cloud AI Experiment." +) +flags.DEFINE_string("logdir", None, "Tensorboard log directory to upload") +flags.DEFINE_bool("one_shot", False, "Iterate through logdir once to upload.") +flags.DEFINE_string("env", "prod", "Environment which this tensorboard belongs to.") +flags.DEFINE_string( + "tensorboard_resource_name", + None, + "Tensorboard resource to create this experiment in. ", +) +flags.DEFINE_integer( + "event_file_inactive_secs", + None, + "Age in seconds of last write after which an event file is considered " "inactive.", +) +flags.DEFINE_string( + "run_name_prefix", + None, + "If present, all runs created by this invocation will have their name " + "prefixed by this value.", +) + +flags.DEFINE_multi_string( + "allowed_plugins", + [ + scalar_metadata.PLUGIN_NAME, + histogram_metadata.PLUGIN_NAME, + distribution_metadata.PLUGIN_NAME, + text_metadata.PLUGIN_NAME, + hparams_metadata.PLUGIN_NAME, + images_metadata.PLUGIN_NAME, + graphs_metadata.PLUGIN_NAME, + ], + "Plugins allowed by the Uploader.", +) + +flags.mark_flags_as_required(["experiment_name", "logdir", "tensorboard_resource_name"]) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + m = re.match( + "projects/(.*)/locations/(.*)/tensorboards/.*", FLAGS.tensorboard_resource_name + ) + project_id = m[1] + region = m[2] + api_client = aiplatform.initializer.global_config.create_client( + client_class=TensorboardClientWithOverride, location_override=region, + ) + + try: + tensorboard = api_client.get_tensorboard(name=FLAGS.tensorboard_resource_name) + except grpc.RpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.NOT_FOUND: + raise app.UsageError( + "Tensorboard resource %s not found" % FLAGS.tensorboard_resource_name, + exitcode=0, + ) + raise + + if tensorboard.blob_storage_path_prefix: + path_prefix = tensorboard.blob_storage_path_prefix + "/" + first_slash_index = path_prefix.find("/") + bucket_name = path_prefix[:first_slash_index] + blob_storage_bucket = storage.Client(project=project_id).bucket(bucket_name) + blob_storage_folder = path_prefix[first_slash_index + 1 :] + else: + raise app.UsageError( + "Tensorboard resource {} is obsolete. Please create a new one.".format( + FLAGS.tensorboard_resource_name + ), + exitcode=0, + ) + + tb_uploader = uploader.TensorBoardUploader( + experiment_name=FLAGS.experiment_name, + experiment_display_name=FLAGS.experiment_display_name, + tensorboard_resource_name=tensorboard.name, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + allowed_plugins=FLAGS.allowed_plugins, + writer_client=api_client, + logdir=FLAGS.logdir, + one_shot=FLAGS.one_shot, + event_file_inactive_secs=FLAGS.event_file_inactive_secs, + run_name_prefix=FLAGS.run_name_prefix, + ) + + tb_uploader.create_experiment() + + print( + "View your Tensorboard at https://{}/experiment/{}".format( + "tensorboard-gcp-prod.uc.r.appspot.com", + tb_uploader.get_experiment_resource_name().replace("/", "+"), + ) + ) + if FLAGS.one_shot: + tb_uploader._upload_once() # pylint: disable=protected-access + else: + tb_uploader.start_uploading() + + +def run_main(): + app.run(main) + + +if __name__ == "__main__": + run_main() diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index a77e491801..ff86fc1cb8 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -37,6 +37,7 @@ pipeline_service_client_v1beta1, prediction_service_client_v1beta1, metadata_service_client_v1beta1, + tensorboard_service_client_v1beta1, ) from google.cloud.aiplatform.compat.services import ( dataset_service_client_v1, @@ -471,6 +472,14 @@ class MetadataClientWithOverride(ClientWithOverride): ) +class TensorboardClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.V1BETA1 + _version_map = ( + (compat.V1BETA1, tensorboard_service_client_v1beta1.TensorboardServiceClient), + ) + + AiPlatformServiceClientWithOverride = TypeVar( "AiPlatformServiceClientWithOverride", DatasetClientWithOverride, @@ -480,6 +489,7 @@ class MetadataClientWithOverride(ClientWithOverride): PipelineClientWithOverride, PredictionClientWithOverride, MetadataClientWithOverride, + TensorboardClientWithOverride, ) diff --git a/setup.py b/setup.py index 3460e20674..e6932eb615 100644 --- a/setup.py +++ b/setup.py @@ -29,10 +29,14 @@ with io.open(readme_filename, encoding="utf-8") as readme_file: readme = readme_file.read() -tensorboard_extra_require = ["tensorflow-cpu >= 2.3.0, <=2.5.0rc"] +tensorboard_extra_require = [ + "tensorflow-cpu >= 2.3.0, <=2.5.0rc", + "grpcio~=1.34.0", + "six~=1.15.0", +] metadata_extra_require = ["pandas >= 1.0.0"] full_extra_require = tensorboard_extra_require + metadata_extra_require -testing_extra_require = full_extra_require + ["grpcio-testing >= 1.37.1"] +testing_extra_require = full_extra_require + ["grpcio-testing ~= 1.34.0"] setuptools.setup( @@ -41,6 +45,11 @@ description=description, long_description=readme, packages=setuptools.PEP420PackageFinder.find(), + entry_points={ + "console_scripts": [ + "tb-gcp-uploader=google.cloud.aiplatform.tensorboard.uploader_main:run_main" + ], + }, namespace_packages=("google", "google.cloud"), author="Google LLC", author_email="googleapis-packages@google.com", diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py new file mode 100644 index 0000000000..c63f729fd3 --- /dev/null +++ b/tests/unit/aiplatform/test_uploader.py @@ -0,0 +1,1454 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for uploader.py.""" + +import logging +import os +import re +from unittest import mock + +import grpc +import grpc_testing +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 +from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import tensor_pb2 +from tensorboard.compat.proto import types_pb2 +from tensorboard.plugins.scalar import metadata as scalars_metadata +from tensorboard.plugins.graph import metadata as graphs_metadata +from tensorboard.summary import v1 as summary_v1 +from tensorboard.uploader import logdir_loader +from tensorboard.uploader import upload_tracker +from tensorboard.uploader import util +from tensorboard.uploader.proto import server_info_pb2 +import tensorflow as tf + +from google.api_core import datetime_helpers +import google.cloud.aiplatform.tensorboard.uploader as uploader_lib +from google.cloud import storage +from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1 +from google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports import ( + grpc as transports_grpc, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_data_v1beta1 as tensorboard_data, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_experiment_v1beta1 as tensorboard_experiment_type, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_run_v1beta1 as tensorboard_run_type, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_time_series_v1beta1 as tensorboard_time_series_type, +) +from google.protobuf import timestamp_pb2 +from google.protobuf import message + +data_compat = uploader_lib.event_file_loader.data_compat +dataclass_compat = uploader_lib.event_file_loader.dataclass_compat +scalar_v2_pb = summary_v1._scalar_summary.scalar_pb +image_pb = summary_v1._image_summary.pb + +_SCALARS_HISTOGRAMS_AND_GRAPHS = frozenset( + (scalars_metadata.PLUGIN_NAME, graphs_metadata.PLUGIN_NAME,) +) + +# Sentinel for `_create_*` helpers, for arguments for which we want to +# supply a default other than the `None` used by the code under test. +_USE_DEFAULT = object() + +_TEST_EXPERIMENT_NAME = "test-experiment" +_TEST_TENSORBOARD_RESOURCE_NAME = ( + "projects/test_project/locations/us-central1/tensorboards/test_tensorboard" +) +_TEST_LOG_DIR_NAME = "/logs/foo" +_TEST_RUN_NAME = "test-run" +_TEST_ONE_PLATFORM_EXPERIMENT_NAME = "{}/experiments/{}".format( + _TEST_TENSORBOARD_RESOURCE_NAME, _TEST_EXPERIMENT_NAME +) +_TEST_ONE_PLATFORM_RUN_NAME = "{}/runs/{}".format( + _TEST_ONE_PLATFORM_EXPERIMENT_NAME, _TEST_RUN_NAME +) +_TEST_TIME_SERIES_NAME = "test-time-series" +_TEST_ONE_PLATFORM_TIME_SERIES_NAME = "{}/timeSeries/{}".format( + _TEST_ONE_PLATFORM_RUN_NAME, _TEST_TIME_SERIES_NAME +) +_TEST_BLOB_STORAGE_FOLDER = "test_folder" + + +def _create_example_graph_bytes(large_attr_size): + graph_def = graph_pb2.GraphDef() + graph_def.node.add(name="alice", op="Person") + graph_def.node.add(name="bob", op="Person") + + graph_def.node[1].attr["small"].s = b"small_attr_value" + graph_def.node[1].attr["large"].s = b"l" * large_attr_size + graph_def.node.add(name="friendship", op="Friendship", input=["alice", "bob"]) + return graph_def.SerializeToString() + + +class AbortUploadError(Exception): + """Exception used in testing to abort the upload process.""" + + +def _create_mock_client(): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. + def create_experiment_response( + tensorboard_experiment_id=None, + tensorboard_experiment=None, # pylint: disable=unused-argument + parent=None, + ): # pylint: disable=unused-argument + return tensorboard_experiment_type.TensorboardExperiment( + name=tensorboard_experiment_id + ) + + def create_run_response( + tensorboard_run=None, # pylint: disable=unused-argument + tensorboard_run_id=None, + parent=None, + ): # pylint: disable=unused-argument + return tensorboard_run_type.TensorboardRun(name=tensorboard_run_id) + + def create_tensorboard_time_series( + tensorboard_time_series=None, parent=None + ): # pylint: disable=unused-argument + return tensorboard_time_series_type.TensorboardTimeSeries( + name=tensorboard_time_series.display_name, + display_name=tensorboard_time_series.display_name, + ) + + test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time() + ) + mock_client = mock.Mock( + spec=tensorboard_service_client_v1beta1.TensorboardServiceClient( + transport=transports_grpc.TensorboardServiceGrpcTransport( + channel=test_channel + ) + ) + ) + mock_client.create_tensorboard_experiment.side_effect = create_experiment_response + mock_client.create_tensorboard_run.side_effect = create_run_response + mock_client.create_tensorboard_time_series.side_effect = ( + create_tensorboard_time_series + ) + return mock_client + + +def _create_uploader( + writer_client=_USE_DEFAULT, + logdir=None, + max_scalar_request_size=_USE_DEFAULT, + max_tensor_request_size=_USE_DEFAULT, + max_tensor_point_size=_USE_DEFAULT, + max_blob_request_size=_USE_DEFAULT, + max_blob_size=_USE_DEFAULT, + logdir_poll_rate_limiter=_USE_DEFAULT, + rpc_rate_limiter=_USE_DEFAULT, + experiment_name=_TEST_EXPERIMENT_NAME, + tensorboard_resource_name=_TEST_TENSORBOARD_RESOURCE_NAME, + blob_storage_bucket=None, + blob_storage_folder=_TEST_BLOB_STORAGE_FOLDER, + description=None, + verbosity=0, # Use 0 to minimize littering the test output. + one_shot=None, +): + if writer_client is _USE_DEFAULT: + writer_client = _create_mock_client() + if max_scalar_request_size is _USE_DEFAULT: + max_scalar_request_size = 128000 + if max_tensor_request_size is _USE_DEFAULT: + max_tensor_request_size = 512000 + if max_blob_request_size is _USE_DEFAULT: + max_blob_request_size = 128000 + if max_blob_size is _USE_DEFAULT: + max_blob_size = 12345 + if max_tensor_point_size is _USE_DEFAULT: + max_tensor_point_size = 16000 + if logdir_poll_rate_limiter is _USE_DEFAULT: + logdir_poll_rate_limiter = util.RateLimiter(0) + if rpc_rate_limiter is _USE_DEFAULT: + rpc_rate_limiter = util.RateLimiter(0) + + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=max_scalar_request_size, + max_tensor_request_size=max_tensor_request_size, + max_tensor_point_size=max_tensor_point_size, + max_blob_request_size=max_blob_request_size, + max_blob_size=max_blob_size, + ) + + return uploader_lib.TensorBoardUploader( + experiment_name=experiment_name, + tensorboard_resource_name=tensorboard_resource_name, + writer_client=writer_client, + logdir=logdir, + allowed_plugins=_SCALARS_HISTOGRAMS_AND_GRAPHS, + upload_limits=upload_limits, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + logdir_poll_rate_limiter=logdir_poll_rate_limiter, + rpc_rate_limiter=rpc_rate_limiter, + description=description, + verbosity=verbosity, + one_shot=one_shot, + ) + + +def _create_request_sender( + experiment_resource_name, api=None, allowed_plugins=_USE_DEFAULT +): + if api is _USE_DEFAULT: + api = _create_mock_client() + if allowed_plugins is _USE_DEFAULT: + allowed_plugins = _SCALARS_HISTOGRAMS_AND_GRAPHS + + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=128000, + max_tensor_request_size=128000, + max_tensor_point_size=52000, + ) + + rpc_rate_limiter = util.RateLimiter(0) + tensor_rpc_rate_limiter = util.RateLimiter(0) + blob_rpc_rate_limiter = util.RateLimiter(0) + + return uploader_lib._BatchedRequestSender( + experiment_resource_name=experiment_resource_name, + api=api, + allowed_plugins=allowed_plugins, + upload_limits=upload_limits, + rpc_rate_limiter=rpc_rate_limiter, + tensor_rpc_rate_limiter=tensor_rpc_rate_limiter, + blob_rpc_rate_limiter=blob_rpc_rate_limiter, + blob_storage_bucket=None, + blob_storage_folder=None, + tracker=upload_tracker.UploadTracker(verbosity=0), + ) + + +def _create_scalar_request_sender( + run_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT +): + if api is _USE_DEFAULT: + api = _create_mock_client() + if max_request_size is _USE_DEFAULT: + max_request_size = 128000 + return uploader_lib._ScalarBatchedRequestSender( + run_resource_id=run_resource_id, + api=api, + rpc_rate_limiter=util.RateLimiter(0), + max_request_size=max_request_size, + tracker=upload_tracker.UploadTracker(verbosity=0), + ) + + +def _scalar_event(tag, value): + return event_pb2.Event(summary=scalar_v2_pb(tag, value)) + + +def _grpc_error(code, details): + # Monkey patch insertion for the methods a real grpc.RpcError would have. + error = grpc.RpcError("RPC error %r: %s" % (code, details)) + error.code = lambda: code + error.details = lambda: details + return error + + +def _timestamp_pb(nanos): + result = timestamp_pb2.Timestamp() + result.FromNanoseconds(nanos) + return result + + +class FileWriter(tf.compat.v1.summary.FileWriter): + """FileWriter for test. + + TensorFlow FileWriter uses TensorFlow's Protobuf Python binding + which is largely discouraged in TensorBoard. We do not want a + TB.Writer but require one for testing in integrational style + (writing out event files and use the real event readers). + """ + + def __init__(self, *args, **kwargs): + # Briefly enter graph mode context so this testing FileWriter can be + # created from an eager mode context without triggering a usage error. + with tf.compat.v1.Graph().as_default(): + super(FileWriter, self).__init__(*args, **kwargs) + + def add_test_summary(self, tag, simple_value=1.0, step=None): + """Convenience for writing a simple summary for a given tag.""" + value = summary_pb2.Summary.Value(tag=tag, simple_value=simple_value) + summary = summary_pb2.Summary(value=[value]) + self.add_summary(summary, global_step=step) + + def add_test_tensor_summary(self, tag, tensor, step=None, value_metadata=None): + """Convenience for writing a simple summary for a given tag.""" + value = summary_pb2.Summary.Value( + tag=tag, tensor=tensor, metadata=value_metadata + ) + summary = summary_pb2.Summary(value=[value]) + self.add_summary(summary, global_step=step) + + def add_event(self, event): + if isinstance(event, event_pb2.Event): + tf_event = tf.compat.v1.Event.FromString(event.SerializeToString()) + else: + tf_event = event + if not isinstance(event, bytes): + logging.error( + "Added TensorFlow event proto. " + "Please prefer TensorBoard copy of the proto" + ) + super(FileWriter, self).add_event(tf_event) + + def add_summary(self, summary, global_step=None): + if isinstance(summary, summary_pb2.Summary): + tf_summary = tf.compat.v1.Summary.FromString(summary.SerializeToString()) + else: + tf_summary = summary + if not isinstance(summary, bytes): + logging.error( + "Added TensorFlow summary proto. " + "Please prefer TensorBoard copy of the proto" + ) + super(FileWriter, self).add_summary(tf_summary, global_step) + + def add_session_log(self, session_log, global_step=None): + if isinstance(session_log, event_pb2.SessionLog): + tf_session_log = tf.compat.v1.SessionLog.FromString( + session_log.SerializeToString() + ) + else: + tf_session_log = session_log + if not isinstance(session_log, bytes): + logging.error( + "Added TensorFlow session_log proto. " + "Please prefer TensorBoard copy of the proto" + ) + super(FileWriter, self).add_session_log(tf_session_log, global_step) + + def add_graph(self, graph, global_step=None, graph_def=None): + if isinstance(graph_def, graph_pb2.GraphDef): + tf_graph_def = tf.compat.v1.GraphDef.FromString( + graph_def.SerializeToString() + ) + else: + tf_graph_def = graph_def + + super(FileWriter, self).add_graph( + graph, global_step=global_step, graph_def=tf_graph_def + ) + + def add_meta_graph(self, meta_graph_def, global_step=None): + if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): + tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + else: + tf_meta_graph_def = meta_graph_def + + super(FileWriter, self).add_meta_graph( + meta_graph_def=tf_meta_graph_def, global_step=global_step + ) + + +class TensorboardUploaderTest(tf.test.TestCase): + def test_create_experiment(self): + logdir = _TEST_LOG_DIR_NAME + uploader = _create_uploader(_create_mock_client(), logdir) + uploader.create_experiment() + self.assertEqual(uploader._experiment.name, _TEST_EXPERIMENT_NAME) + + def test_create_experiment_with_name(self): + logdir = _TEST_LOG_DIR_NAME + mock_client = _create_mock_client() + new_name = "This is the new name" + uploader = _create_uploader(mock_client, logdir, experiment_name=new_name) + uploader.create_experiment() + mock_client.create_tensorboard_experiment.assert_called_once() + call_args = mock_client.create_tensorboard_experiment.call_args + self.assertEqual( + call_args[1]["tensorboard_experiment"], + tensorboard_experiment_type.TensorboardExperiment(), + ) + self.assertEqual(call_args[1]["parent"], _TEST_TENSORBOARD_RESOURCE_NAME) + self.assertEqual(call_args[1]["tensorboard_experiment_id"], new_name) + + def test_create_experiment_with_description(self): + logdir = _TEST_LOG_DIR_NAME + mock_client = _create_mock_client() + new_description = """ + **description**" + may have "strange" unicode chars 🌴 \\/<> + """ + uploader = _create_uploader(mock_client, logdir, description=new_description) + uploader.create_experiment() + self.assertEqual(uploader._experiment_name, _TEST_EXPERIMENT_NAME) + mock_client.create_tensorboard_experiment.assert_called_once() + call_args = mock_client.create_tensorboard_experiment.call_args + + tb_experiment = tensorboard_experiment_type.TensorboardExperiment( + description=new_description + ) + + expected_call_args = mock.call( + parent=_TEST_TENSORBOARD_RESOURCE_NAME, + tensorboard_experiment_id=_TEST_EXPERIMENT_NAME, + tensorboard_experiment=tb_experiment, + ) + + self.assertEqual(expected_call_args, call_args) + + def test_create_experiment_with_all_metadata(self): + logdir = _TEST_LOG_DIR_NAME + mock_client = _create_mock_client() + new_description = """ + **description**" + may have "strange" unicode chars 🌴 \\/<> + """ + new_name = "This is a cool name." + uploader = _create_uploader( + mock_client, logdir, experiment_name=new_name, description=new_description + ) + uploader.create_experiment() + self.assertEqual(uploader._experiment_name, new_name) + mock_client.create_tensorboard_experiment.assert_called_once() + call_args = mock_client.create_tensorboard_experiment.call_args + + tb_experiment = tensorboard_experiment_type.TensorboardExperiment( + description=new_description + ) + expected_call_args = mock.call( + parent=_TEST_TENSORBOARD_RESOURCE_NAME, + tensorboard_experiment_id=new_name, + tensorboard_experiment=tb_experiment, + ) + self.assertEqual(call_args, expected_call_args) + + def test_start_uploading_without_create_experiment_fails(self): + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, _TEST_LOG_DIR_NAME) + with self.assertRaisesRegex(RuntimeError, "call create_experiment()"): + uploader.start_uploading() + + def test_start_uploading_scalars(self): + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tensor_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_blob_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + writer_client=mock_client, + logdir=_TEST_LOG_DIR_NAME, + # Send each Event below in a separate WriteScalarRequest + max_scalar_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + verbosity=1, # In order to test the upload tracker. + ) + uploader.create_experiment() + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)] + ), + "run 2": _apply_compat( + [_scalar_event("2.1", 5.0), _scalar_event("2.2", 5.0)] + ), + }, + { + "run 3": _apply_compat( + [_scalar_event("3.1", 5.0), _scalar_event("3.2", 5.0)] + ), + "run 4": _apply_compat( + [_scalar_event("4.1", 5.0), _scalar_event("4.2", 5.0)] + ), + "run 5": _apply_compat( + [_scalar_event("5.1", 5.0), _scalar_event("5.2", 5.0)] + ), + }, + AbortUploadError, + ] + + with mock.patch.object( + uploader, "_logdir_loader", mock_logdir_loader + ), self.assertRaises(AbortUploadError): + uploader.start_uploading() + self.assertEqual(10, mock_client.write_tensorboard_run_data.call_count) + self.assertEqual(10, mock_rate_limiter.tick.call_count) + self.assertEqual(0, mock_tensor_rate_limiter.tick.call_count) + self.assertEqual(0, mock_blob_rate_limiter.tick.call_count) + + # Check upload tracker calls. + self.assertEqual(mock_tracker.send_tracker.call_count, 2) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 10) + self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 0) + + def test_start_uploading_scalars_one_shot(self): + """Check that one-shot uploading stops without AbortUploadError.""" + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + writer_client=mock_client, + logdir=_TEST_LOG_DIR_NAME, + # Send each Event below in a separate WriteScalarRequest + max_scalar_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + verbosity=1, # In order to test the upload tracker. + one_shot=True, + ) + uploader.create_experiment() + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)] + ), + "run 2": _apply_compat( + [_scalar_event("2.1", 5.0), _scalar_event("2.2", 5.0)] + ), + }, + # Note the lack of AbortUploadError here. + ] + + with mock.patch.object(uploader, "_logdir_loader", mock_logdir_loader): + uploader.start_uploading() + + self.assertEqual(4, mock_client.write_tensorboard_run_data.call_count) + self.assertEqual(4, mock_rate_limiter.tick.call_count) + + # Check upload tracker calls. + self.assertEqual(mock_tracker.send_tracker.call_count, 1) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 4) + self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 0) + + def test_upload_empty_logdir(self): + logdir = self.get_temp_dir() + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, logdir) + uploader.create_experiment() + uploader._upload_once() + mock_client.write_tensorboard_run_data.assert_not_called() + + def test_upload_polls_slowly_once_done(self): + class SuccessError(Exception): + pass + + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + upload_call_count_box = [0] + + def mock_upload_once(): + upload_call_count_box[0] += 1 + tick_count = mock_rate_limiter.tick.call_count + self.assertEqual(tick_count, upload_call_count_box[0]) + if tick_count >= 3: + raise SuccessError() + + uploader = _create_uploader( + logdir=self.get_temp_dir(), logdir_poll_rate_limiter=mock_rate_limiter, + ) + uploader._upload_once = mock_upload_once + + uploader.create_experiment() + with self.assertRaises(SuccessError): + uploader.start_uploading() + + def test_upload_swallows_rpc_failure(self): + logdir = self.get_temp_dir() + with FileWriter(logdir) as writer: + writer.add_test_summary("foo") + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, logdir) + uploader.create_experiment() + error = _grpc_error(grpc.StatusCode.INTERNAL, "Failure") + mock_client.write_tensorboard_run_data.side_effect = error + uploader._upload_once() + mock_client.write_tensorboard_run_data.assert_called_once() + + def test_upload_full_logdir(self): + logdir = self.get_temp_dir() + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, logdir) + uploader.create_experiment() + + # Convenience helpers for constructing expected requests. + data = tensorboard_data.TimeSeriesData + point = tensorboard_data.TimeSeriesDataPoint + scalar = tensorboard_data.Scalar + + # First round + writer = FileWriter(logdir) + metadata = summary_pb2.SummaryMetadata( + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="scalars", content=b"12345" + ), + data_class=summary_pb2.DATA_CLASS_SCALAR, + ) + writer.add_test_summary("foo", simple_value=5.0, step=1) + writer.add_test_summary("foo", simple_value=6.0, step=2) + writer.add_test_summary("foo", simple_value=7.0, step=3) + writer.add_test_tensor_summary( + "bar", + tensor=tensor_pb2.TensorProto(dtype=types_pb2.DT_FLOAT, float_val=[8.0]), + step=3, + value_metadata=metadata, + ) + writer.flush() + writer_a = FileWriter(os.path.join(logdir, "a")) + writer_a.add_test_summary("qux", simple_value=9.0, step=2) + writer_a.flush() + uploader._upload_once() + self.assertEqual(3, mock_client.create_tensorboard_time_series.call_count) + call_args_list = mock_client.create_tensorboard_time_series.call_args_list + request = call_args_list[1][1]["tensorboard_time_series"] + self.assertEqual("scalars", request.plugin_name) + self.assertEqual(b"12345", request.plugin_data) + + self.assertEqual(2, mock_client.write_tensorboard_run_data.call_count) + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + request1, request2 = ( + call_args_list[0][1]["time_series_data"], + call_args_list[1][1]["time_series_data"], + ) + _clear_wall_times(request1) + _clear_wall_times(request2) + + expected_request1 = [ + data( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + point(step=1, scalar=scalar(value=5.0)), + point(step=2, scalar=scalar(value=6.0)), + point(step=3, scalar=scalar(value=7.0)), + ], + ), + data( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=3, scalar=scalar(value=8.0))], + ), + ] + expected_request2 = [ + data( + tensorboard_time_series_id="qux", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=2, scalar=scalar(value=9.0))], + ) + ] + self.assertProtoEquals(expected_request1[0], request1[0]) + self.assertProtoEquals(expected_request1[1], request1[1]) + self.assertProtoEquals(expected_request2[0], request2[0]) + + mock_client.write_tensorboard_run_data.reset_mock() + + # Second round + writer.add_test_summary("foo", simple_value=10.0, step=5) + writer.add_test_summary("baz", simple_value=11.0, step=1) + writer.flush() + writer_b = FileWriter(os.path.join(logdir, "b")) + writer_b.add_test_summary("xyz", simple_value=12.0, step=1) + writer_b.flush() + uploader._upload_once() + self.assertEqual(2, mock_client.write_tensorboard_run_data.call_count) + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + request3, request4 = ( + call_args_list[0][1]["time_series_data"], + call_args_list[1][1]["time_series_data"], + ) + _clear_wall_times(request3) + _clear_wall_times(request4) + expected_request3 = [ + data( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=5, scalar=scalar(value=10.0))], + ), + data( + tensorboard_time_series_id="baz", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=11.0))], + ), + ] + expected_request4 = [ + data( + tensorboard_time_series_id="xyz", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=12.0))], + ) + ] + self.assertProtoEquals(expected_request3[0], request3[0]) + self.assertProtoEquals(expected_request3[1], request3[1]) + self.assertProtoEquals(expected_request4[0], request4[0]) + mock_client.write_tensorboard_run_data.reset_mock() + + # Empty third round + uploader._upload_once() + mock_client.write_tensorboard_run_data.assert_not_called() + + def test_verbosity_zero_creates_upload_tracker_with_verbosity_zero(self): + mock_client = _create_mock_client() + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ) as mock_constructor: + uploader = _create_uploader( + mock_client, + _TEST_LOG_DIR_NAME, + verbosity=0, # Explicitly set verbosity to 0. + ) + uploader.create_experiment() + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)] + ), + }, + AbortUploadError, + ] + + with mock.patch.object( + uploader, "_logdir_loader", mock_logdir_loader + ), self.assertRaises(AbortUploadError): + uploader.start_uploading() + + self.assertEqual(mock_constructor.call_count, 1) + self.assertEqual(mock_constructor.call_args[1], {"verbosity": 0}) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 1) + + def test_start_uploading_graphs(self): + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_bucket = mock.create_autospec(storage.Bucket) + mock_blob = mock.create_autospec(storage.Blob) + mock_bucket.blob.return_value = mock_blob + mock_tracker = mock.MagicMock() + + def create_time_series(tensorboard_time_series, parent=None): + return tensorboard_time_series_type.TensorboardTimeSeries( + name=_TEST_ONE_PLATFORM_TIME_SERIES_NAME, + display_name=tensorboard_time_series.display_name, + ) + + mock_client.create_tensorboard_time_series.side_effect = create_time_series + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + writer_client=mock_client, + logdir=_TEST_LOG_DIR_NAME, + # Verify behavior with lots of small chunks + max_blob_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + blob_storage_bucket=mock_bucket, + verbosity=1, # In order to test tracker. + ) + uploader.create_experiment() + + # Of course a real Event stream will never produce the same Event twice, + # but is this test context it's fine to reuse this one. + graph_event = event_pb2.Event(graph_def=_create_example_graph_bytes(950)) + expected_graph_def = graph_pb2.GraphDef.FromString(graph_event.graph_def) + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat([graph_event, graph_event]), + "run 2": _apply_compat([graph_event, graph_event]), + }, + { + "run 3": _apply_compat([graph_event, graph_event]), + "run 4": _apply_compat([graph_event, graph_event]), + "run 5": _apply_compat([graph_event, graph_event]), + }, + AbortUploadError, + ] + + with mock.patch.object( + uploader, "_logdir_loader", mock_logdir_loader + ), self.assertRaises(AbortUploadError): + uploader.start_uploading() + + self.assertEqual(1, mock_client.create_tensorboard_experiment.call_count) + self.assertEqual(10, mock_bucket.blob.call_count) + + blob_ids = set() + for call in mock_bucket.blob.call_args_list: + request = call[0][0] + m = re.match( + "test_folder/tensorboard-.*/test-experiment/.*/{}/(.*)".format( + _TEST_TIME_SERIES_NAME + ), + request, + ) + self.assertIsNotNone(m) + blob_ids.add(m[1]) + + for call in mock_blob.upload_from_string.call_args_list: + request = call[0][0] + actual_graph_def = graph_pb2.GraphDef.FromString(request) + self.assertProtoEquals(expected_graph_def, actual_graph_def) + + for call in mock_client.write_tensorboard_run_data.call_args_list: + kargs = call[1] + time_series_data = kargs["time_series_data"] + self.assertEqual(len(time_series_data), 1) + self.assertEqual( + time_series_data[0].tensorboard_time_series_id, _TEST_TIME_SERIES_NAME + ) + self.assertEqual(len(time_series_data[0].values), 1) + blobs = time_series_data[0].values[0].blobs.values + self.assertEqual(len(blobs), 1) + self.assertIn(blobs[0].id, blob_ids) + + # Check upload tracker calls. + self.assertEqual(mock_tracker.send_tracker.call_count, 2) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 0) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 10) + self.assertLen(mock_tracker.blob_tracker.call_args[0], 1) + self.assertGreater(mock_tracker.blob_tracker.call_args[0][0], 0) + + def test_filter_graphs(self): + # Three graphs: one short, one long, one corrupt. + bytes_0 = _create_example_graph_bytes(123) + bytes_1 = _create_example_graph_bytes(9999) + # invalid (truncated) proto: length-delimited field 1 (0x0a) of + # length 0x7f specified, but only len("bogus") = 5 bytes given + # + bytes_2 = b"\x0a\x7fbogus" + + logdir = self.get_temp_dir() + for (i, b) in enumerate([bytes_0, bytes_1, bytes_2]): + run_dir = os.path.join(logdir, "run_%04d" % i) + event = event_pb2.Event(step=0, wall_time=123 * i, graph_def=b) + with FileWriter(run_dir) as writer: + writer.add_event(event) + + limiter = mock.create_autospec(util.RateLimiter) + limiter.tick.side_effect = [None, AbortUploadError] + mock_bucket = mock.create_autospec(storage.Bucket) + mock_blob = mock.create_autospec(storage.Blob) + mock_bucket.blob.return_value = mock_blob + mock_client = _create_mock_client() + + def create_time_series(tensorboard_time_series, parent=None): + return tensorboard_time_series_type.TensorboardTimeSeries( + name=_TEST_ONE_PLATFORM_TIME_SERIES_NAME, + display_name=tensorboard_time_series.display_name, + ) + + mock_client.create_tensorboard_time_series.side_effect = create_time_series + uploader = _create_uploader( + mock_client, + logdir, + logdir_poll_rate_limiter=limiter, + blob_storage_bucket=mock_bucket, + ) + uploader.create_experiment() + + with self.assertRaises(AbortUploadError): + uploader.start_uploading() + + actual_blobs = [] + for call in mock_blob.upload_from_string.call_args_list: + requests = call[0][0] + actual_blobs.append(requests) + + actual_graph_defs = [] + for blob in actual_blobs: + try: + actual_graph_defs.append(graph_pb2.GraphDef.FromString(blob)) + except message.DecodeError: + actual_graph_defs.append(None) + + with self.subTest("graphs with small attr values should be unchanged"): + expected_graph_def_0 = graph_pb2.GraphDef.FromString(bytes_0) + self.assertEqual(actual_graph_defs[0], expected_graph_def_0) + + with self.subTest("large attr values should be filtered out"): + expected_graph_def_1 = graph_pb2.GraphDef.FromString(bytes_1) + del expected_graph_def_1.node[1].attr["large"] + expected_graph_def_1.node[1].attr["_too_large_attrs"].list.s.append( + b"large" + ) + self.assertEqual(actual_graph_defs[1], expected_graph_def_1) + + with self.subTest("corrupt graphs should be skipped"): + self.assertLen(actual_blobs, 2) + + +class BatchedRequestSenderTest(tf.test.TestCase): + def _populate_run_from_events( + self, n_scalar_events, events, allowed_plugins=_USE_DEFAULT + ): + mock_client = _create_mock_client() + builder = _create_request_sender( + experiment_resource_name="123", + api=mock_client, + allowed_plugins=allowed_plugins, + ) + builder.send_requests({"": _apply_compat(events)}) + scalar_requests = mock_client.write_tensorboard_run_data.call_args_list + if scalar_requests: + self.assertLen(scalar_requests, 1) + self.assertLen(scalar_requests[0][1]["time_series_data"], n_scalar_events) + return scalar_requests + + def test_empty_events(self): + call_args_list = self._populate_run_from_events(0, []) + self.assertProtoEquals(call_args_list, []) + + def test_scalar_events(self): + events = [ + event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2_pb("scalar2", 5.0)), + ] + call_args_lists = self._populate_run_from_events(2, events) + scalar_tag_counts = _extract_tag_counts(call_args_lists) + self.assertEqual(scalar_tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_skips_non_scalar_events(self): + events = [ + event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), + event_pb2.Event(file_version="brain.Event:2"), + ] + call_args_list = self._populate_run_from_events(1, events) + scalar_tag_counts = _extract_tag_counts(call_args_list) + self.assertEqual(scalar_tag_counts, {"scalar1": 1}) + + def test_skips_non_scalar_events_in_scalar_time_series(self): + events = [ + event_pb2.Event(file_version="brain.Event:2"), + event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2_pb("scalar2", 5.0)), + ] + call_args_list = self._populate_run_from_events(2, events) + scalar_tag_counts = _extract_tag_counts(call_args_list) + self.assertEqual(scalar_tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_skips_events_from_disallowed_plugins(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2_pb("foo", 5.0) + ) + call_args_lists = self._populate_run_from_events( + 0, [event], allowed_plugins=frozenset("not-scalars"), + ) + self.assertEqual(call_args_lists, []) + + def test_remembers_first_metadata_in_time_series(self): + scalar_1 = event_pb2.Event(summary=scalar_v2_pb("loss", 4.0)) + scalar_2 = event_pb2.Event(summary=scalar_v2_pb("loss", 3.0)) + scalar_2.summary.value[0].ClearField("metadata") + events = [ + event_pb2.Event(file_version="brain.Event:2"), + scalar_1, + scalar_2, + ] + call_args_list = self._populate_run_from_events(1, events) + scalar_tag_counts = _extract_tag_counts(call_args_list) + self.assertEqual(scalar_tag_counts, {"loss": 2}) + + def test_expands_multiple_values_in_event(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + event.summary.value.add(tag="foo", simple_value=2.0) + event.summary.value.add(tag="foo", simple_value=3.0) + call_args_list = self._populate_run_from_events(1, [event]) + + time_series_data = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=1.0), + ), + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=2.0), + ), + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=3.0), + ), + ], + ) + + self.assertProtoEquals( + time_series_data, call_args_list[0][1]["time_series_data"][0] + ) + + +class ScalarBatchedRequestSenderTest(tf.test.TestCase): + def _add_events(self, sender, events): + for event in events: + for value in event.summary.value: + sender.add_event(event, value, value.metadata) + + def _add_events_and_flush(self, events, expected_n_time_series): + mock_client = _create_mock_client() + sender = _create_scalar_request_sender( + run_resource_id=_TEST_RUN_NAME, api=mock_client, + ) + self._add_events(sender, events) + sender.flush() + + requests = mock_client.write_tensorboard_run_data.call_args_list + self.assertLen(requests, 1) + self.assertLen(requests[0][1]["time_series_data"], expected_n_time_series) + return requests[0] + + def test_aggregation_by_tag(self): + def make_event(step, wall_time, tag, value): + return event_pb2.Event( + step=step, wall_time=wall_time, summary=scalar_v2_pb(tag, value), + ) + + events = [ + make_event(1, 1.0, "one", 11.0), + make_event(1, 2.0, "two", 22.0), + make_event(2, 3.0, "one", 33.0), + make_event(2, 4.0, "two", 44.0), + make_event(1, 5.0, "one", 55.0), # Should preserve duplicate step=1. + make_event(1, 6.0, "three", 66.0), + ] + call_args = self._add_events_and_flush(events, 3) + ts_data = call_args[1]["time_series_data"] + tag_data = { + ts.tensorboard_time_series_id: [ + ( + value.step, + value.wall_time.timestamp_pb().ToSeconds(), + value.scalar.value, + ) + for value in ts.values + ] + for ts in ts_data + } + self.assertEqual( + tag_data, + { + "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)], + "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)], + "three": [(1, 6.0, 66.0)], + }, + ) + + def test_v1_summary(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=5.0) + call_args = self._add_events_and_flush(_apply_compat([event]), 1) + + expected_call_args = mock.call( + tensorboard_run=_TEST_RUN_NAME, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=5.0), + ) + ], + ) + ], + ) + self.assertEqual(expected_call_args, call_args) + + def test_v1_summary_tb_summary(self): + tf_summary = summary_v1.scalar_pb("foo", 5.0) + tb_summary = summary_pb2.Summary.FromString(tf_summary.SerializeToString()) + event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) + call_args = self._add_events_and_flush(_apply_compat([event]), 1) + + expected_call_args = mock.call( + tensorboard_run=_TEST_RUN_NAME, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="scalar_summary", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=5.0), + ) + ], + ) + ], + ) + self.assertEqual(expected_call_args, call_args) + + def test_v2_summary(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2_pb("foo", 5.0) + ) + call_args = self._add_events_and_flush(_apply_compat([event]), 1) + + expected_call_args = mock.call( + tensorboard_run=_TEST_RUN_NAME, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=5.0), + ) + ], + ) + ], + ) + + self.assertEqual(expected_call_args, call_args) + + def test_propagates_experiment_deletion(self): + event = event_pb2.Event(step=1) + event.summary.value.add(tag="foo", simple_value=1.0) + + mock_client = _create_mock_client() + sender = _create_scalar_request_sender("123", mock_client) + self._add_events(sender, _apply_compat([event])) + + error = _grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.write_tensorboard_run_data.side_effect = error + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + sender.flush() + + def test_no_budget_for_base_request(self): + mock_client = _create_mock_client() + long_run_id = "A" * 12 + with self.assertRaises(uploader_lib._OutOfSpaceError) as cm: + _create_scalar_request_sender( + run_resource_id=long_run_id, api=mock_client, max_request_size=12, + ) + self.assertEqual(str(cm.exception), "Byte budget too small for base request") + + def test_no_room_for_single_point(self): + mock_client = _create_mock_client() + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + sender = _create_scalar_request_sender("123", mock_client, max_request_size=12) + with self.assertRaises(RuntimeError) as cm: + self._add_events(sender, [event]) + self.assertEqual(str(cm.exception), "add_event failed despite flush") + + def test_break_at_run_boundary(self): + mock_client = _create_mock_client() + # Choose run name sizes such that one run fits in a 1024 byte request, + # but not two. + long_run_1 = "A" * 768 + long_run_2 = "B" * 768 + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + + sender_1 = _create_scalar_request_sender( + long_run_1, + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + + sender_2 = _create_scalar_request_sender( + long_run_2, + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + self._add_events(sender_1, _apply_compat([event_1])) + self._add_events(sender_2, _apply_compat([event_2])) + sender_1.flush() + sender_2.flush() + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + + for call_args in call_args_list: + _clear_wall_times(call_args[1]["time_series_data"]) + + # Expect two calls despite a single explicit call to flush(). + + expected = [ + mock.call( + tensorboard_run=long_run_1, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, scalar=tensorboard_data.Scalar(value=1.0) + ) + ], + ) + ], + ), + mock.call( + tensorboard_run=long_run_2, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=2, scalar=tensorboard_data.Scalar(value=-2.0) + ) + ], + ) + ], + ), + ] + + self.assertEqual(expected[0], call_args_list[0]) + self.assertEqual(expected[1], call_args_list[1]) + + def test_break_at_tag_boundary(self): + mock_client = _create_mock_client() + # Choose tag name sizes such that one tag fits in a 1024 byte request, + # but not two. Note that tag names appear in both `Tag.name` and the + # summary metadata. + long_tag_1 = "a" * 384 + long_tag_2 = "b" * 384 + event = event_pb2.Event(step=1) + event.summary.value.add(tag=long_tag_1, simple_value=1.0) + event.summary.value.add(tag=long_tag_2, simple_value=2.0) + + sender = _create_scalar_request_sender( + "train", + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + self._add_events(sender, _apply_compat([event])) + sender.flush() + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + + request1 = call_args_list[0][1]["time_series_data"] + _clear_wall_times(request1) + + # Convenience helpers for constructing expected requests. + data = tensorboard_data.TimeSeriesData + point = tensorboard_data.TimeSeriesDataPoint + scalar = tensorboard_data.Scalar + + expected_request1 = [ + data( + tensorboard_time_series_id=long_tag_1, + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=1.0))], + ), + data( + tensorboard_time_series_id=long_tag_2, + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=2.0))], + ), + ] + self.assertProtoEquals(expected_request1[0], request1[0]) + self.assertProtoEquals(expected_request1[1], request1[1]) + + def test_break_at_scalar_point_boundary(self): + mock_client = _create_mock_client() + point_count = 2000 # comfortably saturates a single 1024-byte request + events = [] + for step in range(point_count): + summary = scalar_v2_pb("loss", -2.0 * step) + if step > 0: + summary.value[0].ClearField("metadata") + events.append(event_pb2.Event(summary=summary, step=step)) + + sender = _create_scalar_request_sender( + "train", + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + self._add_events(sender, _apply_compat(events)) + sender.flush() + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + + for call_args in call_args_list: + _clear_wall_times(call_args[1]["time_series_data"]) + + self.assertGreater(len(call_args_list), 1) + self.assertLess(len(call_args_list), point_count) + # This is the observed number of requests when running the test. There + # is no reasonable way to derive this value from just reading the code. + # The number of requests does not have to be 37 to be correct but if it + # changes it probably warrants some investigation or thought. + self.assertEqual(37, len(call_args_list)) + + total_points_in_result = 0 + for call_args in call_args_list: + self.assertLen(call_args[1]["time_series_data"], 1) + self.assertEqual(call_args[1]["tensorboard_run"], "train") + time_series_data = call_args[1]["time_series_data"][0] + self.assertEqual(time_series_data.tensorboard_time_series_id, "loss") + for point in time_series_data.values: + self.assertEqual(point.step, total_points_in_result) + self.assertEqual(point.scalar.value, -2.0 * point.step) + total_points_in_result += 1 + self.assertEqual(total_points_in_result, point_count) + + def test_prunes_tags_and_runs(self): + mock_client = _create_mock_client() + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + + add_point_call_count_box = [0] + + def mock_add_point(byte_budget_manager_self, point): + # Simulate out-of-space error the first time that we try to store + # the second point. + add_point_call_count_box[0] += 1 + if add_point_call_count_box[0] == 2: + raise uploader_lib._OutOfSpaceError() + + with mock.patch.object( + uploader_lib._ByteBudgetManager, "add_point", mock_add_point, + ): + sender = _create_scalar_request_sender("123", mock_client) + self._add_events(sender, _apply_compat([event_1])) + self._add_events(sender, _apply_compat([event_2])) + sender.flush() + + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + request1, request2 = ( + call_args_list[0][1]["time_series_data"], + call_args_list[1][1]["time_series_data"], + ) + _clear_wall_times(request1) + _clear_wall_times(request2) + + # Convenience helpers for constructing expected requests. + data = tensorboard_data.TimeSeriesData + point = tensorboard_data.TimeSeriesDataPoint + scalar = tensorboard_data.Scalar + + expected_request1 = [ + data( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=1.0))], + ) + ] + + expected_request2 = [ + data( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=2, scalar=scalar(value=-2.0))], + ) + ] + self.assertProtoEquals(expected_request1[0], request1[0]) + self.assertProtoEquals(expected_request2[0], request2[0]) + + def test_wall_time_precision(self): + # Test a wall time that is exactly representable in float64 but has enough + # digits to incur error if converted to nanoseconds the naive way (* 1e9). + event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119) + event1.summary.value.add(tag="foo", simple_value=1.0) + # Test a wall time where as a float64, the fractional part on its own will + # introduce error if truncated to 9 decimal places instead of rounded. + event2 = event_pb2.Event(step=2, wall_time=1.000000002) + event2.summary.value.add(tag="foo", simple_value=2.0) + call_args = self._add_events_and_flush(_apply_compat([event1, event2]), 1) + self.assertEqual( + datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb( + _timestamp_pb(1567808404765432119) + ), + call_args[1]["time_series_data"][0].values[0].wall_time, + ) + self.assertEqual( + datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb( + _timestamp_pb(1000000002) + ), + call_args[1]["time_series_data"][0].values[1].wall_time, + ) + + +class VarintCostTest(tf.test.TestCase): + def test_varint_cost(self): + self.assertEqual(uploader_lib._varint_cost(0), 1) + self.assertEqual(uploader_lib._varint_cost(7), 1) + self.assertEqual(uploader_lib._varint_cost(127), 1) + self.assertEqual(uploader_lib._varint_cost(128), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128 - 1), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128), 3) + + +def _clear_wall_times(repeated_time_series_data): + """Clears the wall_time fields in a TimeSeriesData to be deterministic. + + Args: + repeated_time_series_data: Iterable of tensorboard_data.TimeSeriesData. + """ + + for time_series_data in repeated_time_series_data: + for value in time_series_data.values: + value.wall_time = None + + +def _apply_compat(events): + initial_metadata = {} + for event in events: + event = data_compat.migrate_event(event) + events = dataclass_compat.migrate_event( + event, initial_metadata=initial_metadata + ) + for migrated_event in events: + yield migrated_event + + +def _extract_tag_counts(call_args_list): + return { + ts_data.tensorboard_time_series_id: len(ts_data.values) + for call_args in call_args_list + for ts_data in call_args[1]["time_series_data"] + } + + +if __name__ == "__main__": + tf.test.main() From 02d50260ab06876824068fb7486f75022afa0b0d Mon Sep 17 00:00:00 2001 From: Yicheng Fang <58752348+yfang1@users.noreply.github.com> Date: Tue, 11 May 2021 06:01:18 -0700 Subject: [PATCH 25/36] fix: updating TB webserver hostname (#380) --- google/cloud/aiplatform/tensorboard/uploader_main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py index 60298b5e5c..734d647fb4 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_main.py +++ b/google/cloud/aiplatform/tensorboard/uploader_main.py @@ -129,8 +129,9 @@ def main(argv): tb_uploader.create_experiment() print( - "View your Tensorboard at https://{}/experiment/{}".format( - "tensorboard-gcp-prod.uc.r.appspot.com", + "View your Tensorboard at https://{}.{}/experiment/{}".format( + region, + "tensorboard.googleusercontent.com", tb_uploader.get_experiment_resource_name().replace("/", "+"), ) ) From 7eaedb67a8c8fdbe362ff48430eaab06af8e1605 Mon Sep 17 00:00:00 2001 From: WhiteSource Renovate Date: Wed, 12 May 2021 15:29:37 +0200 Subject: [PATCH 26/36] chore(deps): update dependency google-cloud-aiplatform to v0.8.0 (#381) --- samples/snippets/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index fdf2eb888d..7581fe95ab 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,3 @@ pytest==6.2.4 google-cloud-storage>=1.26.0, <2.0.0dev -google-cloud-aiplatform==0.7.1 +google-cloud-aiplatform==0.8.0 From 56273f7d1329a3404e58af4666297e6d6325f6ed Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Wed, 12 May 2021 15:09:46 -0500 Subject: [PATCH 27/36] feat: Add VPC Peering support to CustomTrainingJob classes (#378) * Add 'network' for VPC Peering in custom training * Blacken code --- google/cloud/aiplatform/training_jobs.py | 62 ++++++++++++++++++++- tests/unit/aiplatform/test_training_jobs.py | 7 +++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 8ef054fc97..f3f447deb6 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1526,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir( worker_pool_specs: _DistributedTrainingSpec, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1538,6 +1539,11 @@ def _prepare_training_task_inputs_and_output_dir( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. Returns: Training task inputs and Output directory for custom job. """ @@ -1556,6 +1562,8 @@ def _prepare_training_task_inputs_and_output_dir( if service_account: training_task_inputs["serviceAccount"] = service_account + if network: + training_task_inputs["network"] = network return training_task_inputs, base_output_dir @@ -1803,6 +1811,7 @@ def run( model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, @@ -1891,6 +1900,11 @@ def run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -1981,6 +1995,7 @@ def run( environment_variables=environment_variables, base_output_dir=base_output_dir, service_account=service_account, + network=network, bigquery_destination=bigquery_destination, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, @@ -2008,6 +2023,7 @@ def _run( environment_variables: Optional[Dict[str, str]] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2061,6 +2077,11 @@ def _run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -2130,7 +2151,10 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir, service_account + worker_pool_specs=worker_pool_specs, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, ) model = self._run_job( @@ -2375,6 +2399,7 @@ def run( model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, @@ -2456,6 +2481,11 @@ def run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -2545,6 +2575,7 @@ def run( environment_variables=environment_variables, base_output_dir=base_output_dir, service_account=service_account, + network=network, bigquery_destination=bigquery_destination, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, @@ -2571,6 +2602,7 @@ def _run( environment_variables: Optional[Dict[str, str]] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, @@ -2621,6 +2653,11 @@ def _run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): The BigQuery project location where the training data is to be written to. In the given project a new dataset is created @@ -2683,7 +2720,10 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir, service_account + worker_pool_specs=worker_pool_specs, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, ) model = self._run_job( @@ -3709,6 +3749,7 @@ def run( model_display_name: Optional[str] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, bigquery_destination: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, environment_variables: Optional[Dict[str, str]] = None, @@ -3790,6 +3831,11 @@ def run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. bigquery_destination (str): Provide this field if `dataset` is a BiqQuery dataset. The BigQuery project location where the training data is to @@ -3874,6 +3920,7 @@ def run( environment_variables=environment_variables, base_output_dir=base_output_dir, service_account=service_account, + network=network, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, test_fraction_split=test_fraction_split, @@ -3900,6 +3947,7 @@ def _run( environment_variables: Optional[Dict[str, str]] = None, base_output_dir: Optional[str] = None, service_account: Optional[str] = None, + network: Optional[str] = None, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -3951,6 +3999,11 @@ def _run( service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. training_fraction_split (float): The fraction of the input data that is to be used to train the Model. @@ -3999,7 +4052,10 @@ def _run( training_task_inputs, base_output_dir, ) = self._prepare_training_task_inputs_and_output_dir( - worker_pool_specs, base_output_dir, service_account + worker_pool_specs=worker_pool_specs, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, ) model = self._run_job( diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index c3c0e33863..8fd82c7727 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -109,6 +109,7 @@ ) _TEST_ALT_PROJECT = "test-project-alt" _TEST_ALT_LOCATION = "europe-west4" +_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}" _TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml" _TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml" @@ -598,6 +599,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( dataset=mock_tabular_dataset, base_output_dir=_TEST_BASE_OUTPUT_DIR, service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, replica_count=1, @@ -700,6 +702,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, "serviceAccount": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, }, struct_pb2.Value(), ), @@ -2539,6 +2542,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, base_output_dir=_TEST_BASE_OUTPUT_DIR, service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, args=_TEST_RUN_ARGS, replica_count=1, machine_type=_TEST_MACHINE_TYPE, @@ -2621,6 +2625,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset( "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, "serviceAccount": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, }, struct_pb2.Value(), ), @@ -2970,6 +2975,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( model_display_name=_TEST_MODEL_DISPLAY_NAME, base_output_dir=_TEST_BASE_OUTPUT_DIR, service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, args=_TEST_RUN_ARGS, environment_variables=_TEST_ENVIRONMENT_VARIABLES, replica_count=1, @@ -3065,6 +3071,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset( "workerPoolSpecs": [true_worker_pool_spec], "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, "serviceAccount": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, }, struct_pb2.Value(), ), From cc1a7084f7715c94657d5a3b3374c0fc9a86a299 Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Wed, 12 May 2021 14:08:19 -0700 Subject: [PATCH 28/36] feat: Add AutoML vision, Custom training job, and generic prediction samples (#300) * debug mock issue * new mock * more samples * more samples * add next sample/test * add sample/test * run black * Add new Dataset import mocks, fix MBSDK sample tests * Add license headers, update Endpoint mocks/usage * type updates * sasha comment fixes * fix test errors after review update * fix: type for instances * Lint SDK samples * Fix flake8 import order nits Co-authored-by: Vinny Senthil Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> --- ...mage_classification_training_job_sample.py | 54 ++++++++++++++++++ ...classification_training_job_sample_test.py | 56 +++++++++++++++++++ samples/model-builder/conftest.py | 19 +++++++ .../custom_training_job_sample.py | 49 ++++++++++++++++ .../custom_training_job_sample_test.py | 50 +++++++++++++++++ .../model-builder/endpoint_predict_sample.py | 32 +++++++++++ .../endpoint_predict_sample_test.py | 37 ++++++++++++ ...ge_dataset_create_classification_sample.py | 38 +++++++++++++ ...taset_create_classification_sample_test.py | 40 +++++++++++++ ..._dataset_create_object_detection_sample.py | 38 +++++++++++++ ...set_create_object_detection_sample_test.py | 41 ++++++++++++++ .../image_dataset_create_sample.py | 31 ++++++++++ .../image_dataset_create_sample_test.py | 32 +++++++++++ .../image_dataset_import_data_sample.py | 37 ++++++++++++ .../image_dataset_import_data_sample_test.py | 40 +++++++++++++ samples/model-builder/test_constants.py | 5 ++ 16 files changed, 599 insertions(+) create mode 100644 samples/model-builder/automl_image_classification_training_job_sample.py create mode 100644 samples/model-builder/automl_image_classification_training_job_sample_test.py create mode 100644 samples/model-builder/custom_training_job_sample.py create mode 100644 samples/model-builder/custom_training_job_sample_test.py create mode 100644 samples/model-builder/endpoint_predict_sample.py create mode 100644 samples/model-builder/endpoint_predict_sample_test.py create mode 100644 samples/model-builder/image_dataset_create_classification_sample.py create mode 100644 samples/model-builder/image_dataset_create_classification_sample_test.py create mode 100644 samples/model-builder/image_dataset_create_object_detection_sample.py create mode 100644 samples/model-builder/image_dataset_create_object_detection_sample_test.py create mode 100644 samples/model-builder/image_dataset_create_sample.py create mode 100644 samples/model-builder/image_dataset_create_sample_test.py create mode 100644 samples/model-builder/image_dataset_import_data_sample.py create mode 100644 samples/model-builder/image_dataset_import_data_sample_test.py diff --git a/samples/model-builder/automl_image_classification_training_job_sample.py b/samples/model-builder/automl_image_classification_training_job_sample.py new file mode 100644 index 0000000000..502caf008d --- /dev/null +++ b/samples/model-builder/automl_image_classification_training_job_sample.py @@ -0,0 +1,54 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_automl_image_classification_training_job_sample] +def automl_image_classification_training_job_sample( + project: str, location: str, dataset_id: str, display_name: str, +): + aiplatform.init(project=project, location=location) + + dataset = aiplatform.ImageDataset(dataset_id) + + job = aiplatform.AutoMLImageTrainingJob( + display_name=display_name, + prediction_type="classification", + multi_label=False, + model_type="CLOUD", + base_model=None, + ) + + model = job.run( + dataset=dataset, + model_display_name=display_name, + training_fraction_split=0.6, + validation_fraction_split=0.2, + test_fraction_split=0.2, + budget_milli_node_hours=8000, + disable_early_stopping=False, + ) + + print(model.display_name) + print(model.name) + print(model.resource_name) + print(model.description) + print(model.uri) + + return model + + +# [END aiplatform_sdk_automl_image_classification_training_job_sample] diff --git a/samples/model-builder/automl_image_classification_training_job_sample_test.py b/samples/model-builder/automl_image_classification_training_job_sample_test.py new file mode 100644 index 0000000000..a402340f77 --- /dev/null +++ b/samples/model-builder/automl_image_classification_training_job_sample_test.py @@ -0,0 +1,56 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import automl_image_classification_training_job_sample +import test_constants as constants + + +def test_automl_image_classification_training_job_sample( + mock_sdk_init, + mock_image_dataset, + mock_get_image_dataset, + mock_get_automl_image_training_job, + mock_run_automl_image_training_job, +): + automl_image_classification_training_job_sample.automl_image_classification_training_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset_id=constants.DATASET_NAME, + display_name=constants.DISPLAY_NAME, + ) + + mock_get_image_dataset.assert_called_once_with(constants.DATASET_NAME) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_automl_image_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + base_model=None, + model_type="CLOUD", + multi_label=False, + prediction_type="classification", + ) + + mock_run_automl_image_training_job.assert_called_once_with( + budget_milli_node_hours=8000, + disable_early_stopping=False, + test_fraction_split=0.2, + training_fraction_split=0.6, + validation_fraction_split=0.2, + model_display_name=constants.DISPLAY_NAME, + dataset=mock_image_dataset, + ) diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 01756f668b..d8c2ed239d 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -124,6 +124,18 @@ def mock_create_video_dataset(mock_video_dataset): """Mocks for SomeDataset.import_data() """ +@pytest.fixture +def mock_import_image_dataset(mock_image_dataset): + with patch.object(mock_image_dataset, "import_data") as mock: + yield mock + + +@pytest.fixture +def mock_import_tabular_dataset(mock_tabular_dataset): + with patch.object(mock_tabular_dataset, "import_data") as mock: + yield mock + + @pytest.fixture def mock_import_text_dataset(mock_text_dataset): with patch.object(mock_text_dataset, "import_data") as mock: @@ -327,6 +339,13 @@ def mock_get_endpoint(mock_endpoint): yield mock_get_endpoint +@pytest.fixture +def mock_endpoint_predict(mock_endpoint): + with patch.object(mock_endpoint, "predict") as mock: + mock.return_value = [] + yield mock + + @pytest.fixture def mock_endpoint_explain(mock_endpoint): with patch.object(mock_endpoint, "explain") as mock_endpoint_explain: diff --git a/samples/model-builder/custom_training_job_sample.py b/samples/model-builder/custom_training_job_sample.py new file mode 100644 index 0000000000..14c874e3a5 --- /dev/null +++ b/samples/model-builder/custom_training_job_sample.py @@ -0,0 +1,49 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_custom_training_job_sample] +def custom_training_job_sample( + project: str, + location: str, + bucket: str, + display_name: str, + script_path: str, + script_args: str, + container_uri: str, + model_serving_container_image_uri: str, + requirements: str, + replica_count: int, +): + aiplatform.init(project=project, location=location, staging_bucket=bucket) + + job = aiplatform.CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + requirements=requirements, + model_serving_container_image_uri=model_serving_container_image_uri, + ) + + model = job.run( + args=script_args, replica_count=replica_count, model_display_name=display_name + ) + + return model + + +# [END aiplatform_sdk_custom_training_job_sample] diff --git a/samples/model-builder/custom_training_job_sample_test.py b/samples/model-builder/custom_training_job_sample_test.py new file mode 100644 index 0000000000..40d12fb332 --- /dev/null +++ b/samples/model-builder/custom_training_job_sample_test.py @@ -0,0 +1,50 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import custom_training_job_sample +import test_constants as constants + + +def test_custom_training_job_sample( + mock_sdk_init, mock_get_custom_training_job, mock_run_custom_training_job +): + custom_training_job_sample.custom_training_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + bucket=constants.STAGING_BUCKET, + display_name=constants.DISPLAY_NAME, + script_path=constants.PYTHON_PACKAGE, + script_args=constants.PYTHON_PACKAGE_CMDARGS, + container_uri=constants.TRAIN_IMAGE, + model_serving_container_image_uri=constants.DEPLOY_IMAGE, + requirements=[], + replica_count=1, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION, + staging_bucket=constants.STAGING_BUCKET, + ) + + mock_get_custom_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + container_uri=constants.TRAIN_IMAGE, + model_serving_container_image_uri=constants.DEPLOY_IMAGE, + requirements=[], + script_path=constants.PYTHON_PACKAGE, + ) + + mock_run_custom_training_job.assert_called_once() diff --git a/samples/model-builder/endpoint_predict_sample.py b/samples/model-builder/endpoint_predict_sample.py new file mode 100644 index 0000000000..98b7450c51 --- /dev/null +++ b/samples/model-builder/endpoint_predict_sample.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_endpoint_predict_sample] +def endpoint_predict_sample( + project: str, location: str, instances: list, endpoint: str +): + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint) + + prediction = endpoint.predict(instances=instances) + print(prediction) + return prediction + + +# [END aiplatform_sdk_endpoint_predict_sample] diff --git a/samples/model-builder/endpoint_predict_sample_test.py b/samples/model-builder/endpoint_predict_sample_test.py new file mode 100644 index 0000000000..8c2d4e8e10 --- /dev/null +++ b/samples/model-builder/endpoint_predict_sample_test.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import endpoint_predict_sample +import test_constants as constants + + +def test_endpoint_predict_sample( + mock_sdk_init, mock_endpoint_predict, mock_get_endpoint +): + + endpoint_predict_sample.endpoint_predict_sample( + project=constants.PROJECT, + location=constants.LOCATION, + instances=[], + endpoint=constants.ENDPOINT_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME) + + mock_endpoint_predict.assert_called_once_with(instances=[]) diff --git a/samples/model-builder/image_dataset_create_classification_sample.py b/samples/model-builder/image_dataset_create_classification_sample.py new file mode 100644 index 0000000000..ca53cfb7d2 --- /dev/null +++ b/samples/model-builder/image_dataset_create_classification_sample.py @@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_create_classification_sample] +def image_dataset_create_classification_sample( + project: str, location: str, display_name: str, src_uris: list +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification, + ) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + print(ds.metadata_schema_uri) + return ds + + +# [END aiplatform_sdk_image_dataset_create_classification_sample] diff --git a/samples/model-builder/image_dataset_create_classification_sample_test.py b/samples/model-builder/image_dataset_create_classification_sample_test.py new file mode 100644 index 0000000000..0627d26339 --- /dev/null +++ b/samples/model-builder/image_dataset_create_classification_sample_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import image_dataset_create_classification_sample + +import test_constants as constants + + +def test_image_dataset_create_classification_sample( + mock_sdk_init, mock_create_image_dataset +): + image_dataset_create_classification_sample.image_dataset_create_classification_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.image.single_label_classification, + ) diff --git a/samples/model-builder/image_dataset_create_object_detection_sample.py b/samples/model-builder/image_dataset_create_object_detection_sample.py new file mode 100644 index 0000000000..cdcdca009e --- /dev/null +++ b/samples/model-builder/image_dataset_create_object_detection_sample.py @@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_create_object_detection_sample] +def image_dataset_create_object_detection_sample( + project: str, location: str, display_name: str, src_uris: list +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.image.bounding_box, + ) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + print(ds.metadata_schema_uri) + return ds + + +# [END aiplatform_sdk_image_dataset_create_object_detection_sample] diff --git a/samples/model-builder/image_dataset_create_object_detection_sample_test.py b/samples/model-builder/image_dataset_create_object_detection_sample_test.py new file mode 100644 index 0000000000..722a0e2a20 --- /dev/null +++ b/samples/model-builder/image_dataset_create_object_detection_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import image_dataset_create_object_detection_sample + +import test_constants as constants + + +def test_image_dataset_create_object_detection_sample( + mock_sdk_init, mock_create_image_dataset +): + image_dataset_create_object_detection_sample.image_dataset_create_object_detection_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.image.bounding_box, + ) diff --git a/samples/model-builder/image_dataset_create_sample.py b/samples/model-builder/image_dataset_create_sample.py new file mode 100644 index 0000000000..d5821ff7da --- /dev/null +++ b/samples/model-builder/image_dataset_create_sample.py @@ -0,0 +1,31 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_create_sample] +def image_dataset_create_sample(project: str, location: str, display_name: str): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create(display_name=display_name) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_image_dataset_create_sample] diff --git a/samples/model-builder/image_dataset_create_sample_test.py b/samples/model-builder/image_dataset_create_sample_test.py new file mode 100644 index 0000000000..9d04536184 --- /dev/null +++ b/samples/model-builder/image_dataset_create_sample_test.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import image_dataset_create_sample +import test_constants as constants + + +def test_image_dataset_create_sample(mock_sdk_init, mock_create_image_dataset): + image_dataset_create_sample.image_dataset_create_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) diff --git a/samples/model-builder/image_dataset_import_data_sample.py b/samples/model-builder/image_dataset_import_data_sample.py new file mode 100644 index 0000000000..40ca2c75a8 --- /dev/null +++ b/samples/model-builder/image_dataset_import_data_sample.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_image_dataset_import_data_sample] +def image_dataset_import_data_sample( + project: str, location: str, src_uris: list, import_schema_uri: str, dataset_id: str +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset(dataset_id) + + ds = ds.import_data( + gcs_source=src_uris, import_schema_uri=import_schema_uri, sync=True + ) + + print(ds.display_name) + print(ds.name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_image_dataset_import_data_sample] diff --git a/samples/model-builder/image_dataset_import_data_sample_test.py b/samples/model-builder/image_dataset_import_data_sample_test.py new file mode 100644 index 0000000000..e237b115f3 --- /dev/null +++ b/samples/model-builder/image_dataset_import_data_sample_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import image_dataset_import_data_sample + +import test_constants as constants + + +def test_image_dataset_import_data_sample( + mock_sdk_init, mock_import_image_dataset, mock_get_image_dataset +): + + image_dataset_import_data_sample.image_dataset_import_data_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + import_schema_uri=None, + dataset_id=constants.DATASET_NAME, + ) + + mock_get_image_dataset.assert_called_once_with(constants.DATASET_NAME) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_import_image_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, import_schema_uri=None, sync=True + ) diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 69da01dbd8..aa92434b95 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -54,6 +54,11 @@ ENCRYPTION_SPEC_KEY_NAME = f"{PARENT}/keyRings/{RESOURCE_ID}/cryptoKeys/{RESOURCE_ID_2}" +PYTHON_PACKAGE = "gs://my-packages/training.tar.gz" +PYTHON_PACKAGE_CMDARGS = f"--model-dir={GCS_DESTINATION}" +TRAIN_IMAGE = "gcr.io/train_image:latest" +DEPLOY_IMAGE = "gcr.io/deploy_image:latest" + PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output" PREDICTION_TABULAR_CLASSIFICATION_INSTANCE = [ From dcc459d55890961a8aa3cadb696c023a991eea05 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Thu, 13 May 2021 11:04:42 -0400 Subject: [PATCH 29/36] fix: enable aiplatform unit tests --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 916996a6fc..cd85c2b17e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -101,7 +101,7 @@ def default(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit", "gapic"), + os.path.join("tests", "unit"), *session.posargs, ) From 6bda45755c0de383ae8b1ee5b88ffe984c1ece88 Mon Sep 17 00:00:00 2001 From: WhiteSource Renovate Date: Fri, 14 May 2021 19:40:23 +0200 Subject: [PATCH 30/36] chore(deps): update dependency six to ~=1.16.0 (#389) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a0a1a29bf2..2fecdeeee8 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ tensorboard_extra_require = [ "tensorflow-cpu >= 2.3.0, <=2.5.0rc", "grpcio~=1.34.0", - "six~=1.15.0", + "six~=1.16.0", ] metadata_extra_require = ["pandas >= 1.0.0"] full_extra_require = tensorboard_extra_require + metadata_extra_require From 1b0998881e3f309968c839ba4d3bd5fb3e571312 Mon Sep 17 00:00:00 2001 From: WhiteSource Renovate Date: Fri, 14 May 2021 22:02:44 +0200 Subject: [PATCH 31/36] chore(deps): update dependency tensorflow-cpu to >=2.3.0, <=2.5.0 (#390) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2fecdeeee8..dd9e641f9f 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ readme = readme_file.read() tensorboard_extra_require = [ - "tensorflow-cpu >= 2.3.0, <=2.5.0rc", + "tensorflow-cpu>=2.3.0, <=2.5.0", "grpcio~=1.34.0", "six~=1.16.0", ] From 569d4cd03e888fde0171f7b0060695a14f99b072 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Fri, 14 May 2021 16:07:50 -0700 Subject: [PATCH 32/36] samples: adds batch prediction, training job for text with SDK use cases (#383) * samples: adds batch prediction, training job for text SDK use cases --- samples/model-builder/conftest.py | 13 ++++ .../create_batch_prediction_job_sample.py | 6 ++ ...ing_pipeline_text_classification_sample.py | 64 +++++++++++++++++++ ...ipeline_text_classification_sample_test.py | 60 +++++++++++++++++ ..._pipeline_text_entity_extraction_sample.py | 61 ++++++++++++++++++ ...line_text_entity_extraction_sample_test.py | 58 +++++++++++++++++ ...pipeline_text_sentiment_analysis_sample.py | 64 +++++++++++++++++++ ...ine_text_sentiment_analysis_sample_test.py | 60 +++++++++++++++++ 8 files changed, 386 insertions(+) create mode 100644 samples/model-builder/create_training_pipeline_text_classification_sample.py create mode 100644 samples/model-builder/create_training_pipeline_text_classification_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py create mode 100644 samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py create mode 100644 samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index d8c2ed239d..c6bbd30fc0 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -221,6 +221,19 @@ def mock_run_automl_image_training_job(mock_image_training_job): yield mock +@pytest.fixture +def mock_get_automl_text_training_job(mock_text_training_job): + with patch.object(aiplatform, "AutoMLTextTrainingJob") as mock: + mock.return_value = mock_text_training_job + yield mock + + +@pytest.fixture +def mock_run_automl_text_training_job(mock_text_training_job): + with patch.object(mock_text_training_job, "run") as mock: + yield mock + + @pytest.fixture def mock_get_custom_training_job(mock_custom_training_job): with patch.object(aiplatform, "CustomTrainingJob") as mock: diff --git a/samples/model-builder/create_batch_prediction_job_sample.py b/samples/model-builder/create_batch_prediction_job_sample.py index 9bd5c697a5..cb5a5d3ad8 100644 --- a/samples/model-builder/create_batch_prediction_job_sample.py +++ b/samples/model-builder/create_batch_prediction_job_sample.py @@ -17,6 +17,9 @@ from google.cloud import aiplatform +# [START aiplatform_sdk_create_batch_prediction_job_text_classification_sample] +# [START aiplatform_sdk_create_batch_prediction_job_text_entity_extraction_sample] +# [START aiplatform_sdk_create_batch_prediction_job_text_sentiment_analysis_sample] # [START aiplatform_sdk_create_batch_prediction_job_sample] def create_batch_prediction_job_sample( project: str, @@ -46,4 +49,7 @@ def create_batch_prediction_job_sample( return batch_prediction_job +# [END aiplatform_sdk_create_batch_prediction_job_text_sentiment_analysis_sample] +# [END aiplatform_sdk_create_batch_prediction_job_text_entity_extraction_sample] +# [END aiplatform_sdk_create_batch_prediction_job_text_classification_sample] # [END aiplatform_sdk_create_batch_prediction_job_sample] diff --git a/samples/model-builder/create_training_pipeline_text_classification_sample.py b/samples/model-builder/create_training_pipeline_text_classification_sample.py new file mode 100644 index 0000000000..9306a82084 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_classification_sample.py @@ -0,0 +1,64 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_text_classification_sample] +def create_training_pipeline_text_classification_sample( + project: str, + location: str, + display_name: str, + dataset_id: int, + model_display_name: Optional[str] = None, + multi_label: bool = False, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLTextTrainingJob( + display_name=display_name, + prediction_type="classification", + multi_label=multi_label, + ) + + text_dataset = aiplatform.TextDataset(dataset_id) + + model = job.run( + dataset=text_dataset, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_text_classification_sample] diff --git a/samples/model-builder/create_training_pipeline_text_classification_sample_test.py b/samples/model-builder/create_training_pipeline_text_classification_sample_test.py new file mode 100644 index 0000000000..6f54218e45 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_classification_sample_test.py @@ -0,0 +1,60 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_text_classification_sample +import test_constants as constants + + +def test_create_training_pipeline_text_classification_sample( + mock_sdk_init, + mock_text_dataset, + mock_get_automl_text_training_job, + mock_run_automl_text_training_job, + mock_get_text_dataset, +): + + create_training_pipeline_text_classification_sample.create_training_pipeline_text_classification_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_text_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_text_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + multi_label=False, + prediction_type="classification", + ) + mock_run_automl_text_training_job.assert_called_once_with( + dataset=mock_text_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py new file mode 100644 index 0000000000..2d53cb2d63 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py @@ -0,0 +1,61 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_text_entity_extraction_sample] +def create_training_pipeline_text_entity_extraction_sample( + project: str, + location: str, + display_name: str, + dataset_id: int, + model_display_name: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLTextTrainingJob( + display_name=display_name, prediction_type="extraction" + ) + + text_dataset = aiplatform.TextDataset(dataset_id) + + model = job.run( + dataset=text_dataset, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_text_entity_extraction_sample] diff --git a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py new file mode 100644 index 0000000000..215b123942 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py @@ -0,0 +1,58 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_text_entity_extraction_sample +import test_constants as constants + + +def test_create_training_pipeline_text_clentity_extraction_sample( + mock_sdk_init, + mock_text_dataset, + mock_get_automl_text_training_job, + mock_run_automl_text_training_job, + mock_get_text_dataset, +): + + create_training_pipeline_text_entity_extraction_sample.create_training_pipeline_text_entity_extraction_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_text_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_text_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, prediction_type="extraction" + ) + mock_run_automl_text_training_job.assert_called_once_with( + dataset=mock_text_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py new file mode 100644 index 0000000000..685bed6feb --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py @@ -0,0 +1,64 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_text_sentiment_analysis_sample] +def create_training_pipeline_text_sentiment_analysis_sample( + project: str, + location: str, + display_name: str, + dataset_id: int, + model_display_name: Optional[str] = None, + sentiment_max: int = 10, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLTextTrainingJob( + display_name=display_name, + prediction_type="sentiment", + sentiment_max=sentiment_max, + ) + + text_dataset = aiplatform.TextDataset(dataset_id) + + model = job.run( + dataset=text_dataset, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_text_sentiment_analysis_sample] diff --git a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py new file mode 100644 index 0000000000..6ae5f414bd --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py @@ -0,0 +1,60 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_text_sentiment_analysis_sample +import test_constants as constants + + +def test_create_training_pipeline_text_sentiment_analysis_sample( + mock_sdk_init, + mock_text_dataset, + mock_get_automl_text_training_job, + mock_run_automl_text_training_job, + mock_get_text_dataset, +): + + create_training_pipeline_text_sentiment_analysis_sample.create_training_pipeline_text_sentiment_analysis_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_text_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_text_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + prediction_type="sentiment", + sentiment_max=10, + ) + mock_run_automl_text_training_job.assert_called_once_with( + dataset=mock_text_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) From 47c15300d6c8e879e1d7a10ad0c79e2bb4f18aee Mon Sep 17 00:00:00 2001 From: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com> Date: Fri, 14 May 2021 19:12:07 -0600 Subject: [PATCH 33/36] fix(deps): add packaging requirement (#392) Add packaging requirement. packaging.version is used for a version comparison in transports/base.py and is needed after the upgrade to gapic-generator-python 0.46.3 --- setup.py | 1 + testing/constraints-3.6.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/setup.py b/setup.py index dd9e641f9f..1887620bd6 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ install_requires=( "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", "proto-plus >= 1.10.1", + "packaging >= 14.3", "google-cloud-storage >= 1.32.0, < 2.0.0dev", "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", ), diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index a247634611..6753ac710d 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -11,3 +11,4 @@ proto-plus==1.10.1 mock==4.0.2 google-cloud-storage==1.32.0 google-auth==1.25.0 # TODO: Remove when google-api-core >= 1.26.0 is required +packaging==14.3 From 066624b7c2ab3af281b7f63e47c990efbcd52673 Mon Sep 17 00:00:00 2001 From: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com> Date: Sat, 15 May 2021 19:51:22 -0400 Subject: [PATCH 34/36] fix: rollback six to 1.15 (#391) * fix: remove grpcio and rollback six to 1.15 * fix: add grpcio 1.34 back --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1887620bd6..4a91bfb0ef 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ tensorboard_extra_require = [ "tensorflow-cpu>=2.3.0, <=2.5.0", "grpcio~=1.34.0", - "six~=1.16.0", + "six~=1.15.0", ] metadata_extra_require = ["pandas >= 1.0.0"] full_extra_require = tensorboard_extra_require + metadata_extra_require From f27e17f4fba9a1acb7c583b5dd2e6575bf823c3b Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Mon, 17 May 2021 16:16:07 -0500 Subject: [PATCH 35/36] samples: Add Batch Prediction Job with dedicated resources sample (#396) Manually tested usage with: ```py import os import create_batch_prediction_job_dedicated_resources_sample bpj = create_batch_prediction_job_dedicated_resources_sample.create_batch_prediction_job_dedicated_resources_sample( project=os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT"), location='us-central1', model_resource_name='3512561418744365056', job_display_name='temp_create_batch_prediction_job_test_manual_vinnys', gcs_source='gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl', gcs_destination='gs://ucaip-samples-test-output/', ) ``` --- ...ediction_job_dedicated_resources_sample.py | 59 +++++++++++++++++++ ...ion_job_dedicated_resources_sample_test.py | 55 +++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 samples/model-builder/create_batch_prediction_job_dedicated_resources_sample.py create mode 100644 samples/model-builder/create_batch_prediction_job_dedicated_resources_sample_test.py diff --git a/samples/model-builder/create_batch_prediction_job_dedicated_resources_sample.py b/samples/model-builder/create_batch_prediction_job_dedicated_resources_sample.py new file mode 100644 index 0000000000..bdaf170f10 --- /dev/null +++ b/samples/model-builder/create_batch_prediction_job_dedicated_resources_sample.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Union + +from google.cloud import aiplatform, aiplatform_v1 + + +# [START aiplatform_sdk_create_batch_prediction_job_dedicated_resources_sample] +def create_batch_prediction_job_dedicated_resources_sample( + project: str, + location: str, + model_resource_name: str, + job_display_name: str, + gcs_source: Union[str, Sequence[str]], + gcs_destination: str, + machine_type: str = "n1-standard-2", + accelerator_count: int = 1, + accelerator_type: Union[str, aiplatform_v1.AcceleratorType] = "NVIDIA_TESLA_K80", + starting_replica_count: int = 1, + max_replica_count: int = 1, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + my_model = aiplatform.Model(model_resource_name) + + batch_prediction_job = my_model.batch_predict( + job_display_name=job_display_name, + gcs_source=gcs_source, + gcs_destination_prefix=gcs_destination, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + starting_replica_count=starting_replica_count, + max_replica_count=max_replica_count, + sync=sync, + ) + + batch_prediction_job.wait() + + print(batch_prediction_job.display_name) + print(batch_prediction_job.resource_name) + print(batch_prediction_job.state) + return batch_prediction_job + + +# [END aiplatform_sdk_create_batch_prediction_job_dedicated_resources_sample] diff --git a/samples/model-builder/create_batch_prediction_job_dedicated_resources_sample_test.py b/samples/model-builder/create_batch_prediction_job_dedicated_resources_sample_test.py new file mode 100644 index 0000000000..109d368023 --- /dev/null +++ b/samples/model-builder/create_batch_prediction_job_dedicated_resources_sample_test.py @@ -0,0 +1,55 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +import create_batch_prediction_job_dedicated_resources_sample +import test_constants as constants + + +@pytest.mark.usefixtures("mock_model") +def test_create_batch_prediction_job_sample( + mock_sdk_init, mock_init_model, mock_batch_predict_model +): + + create_batch_prediction_job_dedicated_resources_sample.create_batch_prediction_job_dedicated_resources_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_resource_name=constants.MODEL_NAME, + job_display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + gcs_destination=constants.GCS_DESTINATION, + machine_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + accelerator_type=constants.ACCELERATOR_TYPE, + starting_replica_count=constants.MIN_REPLICA_COUNT, + max_replica_count=constants.MAX_REPLICA_COUNT, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_model.assert_called_once_with(constants.MODEL_NAME) + mock_batch_predict_model.assert_called_once_with( + job_display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + gcs_destination_prefix=constants.GCS_DESTINATION, + machine_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + accelerator_type=constants.ACCELERATOR_TYPE, + starting_replica_count=constants.MIN_REPLICA_COUNT, + max_replica_count=constants.MAX_REPLICA_COUNT, + sync=True, + ) From 22409c3df5ed028f5f4cdcda02788d896fad883c Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Mon, 17 May 2021 19:43:56 -0400 Subject: [PATCH 36/36] chore: release 0.9.0 (#382) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 16 ++++++++++++++++ setup.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed7a90f32c..02b44159f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [0.9.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.8.0...v0.9.0) (2021-05-17) + + +### Features + +* Add AutoML vision, Custom training job, and generic prediction samples ([#300](https://www.github.com/googleapis/python-aiplatform/issues/300)) ([cc1a708](https://www.github.com/googleapis/python-aiplatform/commit/cc1a7084f7715c94657d5a3b3374c0fc9a86a299)) +* Add VPC Peering support to CustomTrainingJob classes ([#378](https://www.github.com/googleapis/python-aiplatform/issues/378)) ([56273f7](https://www.github.com/googleapis/python-aiplatform/commit/56273f7d1329a3404e58af4666297e6d6325f6ed)) +* AutoML Forecasting, Metadata Experiment Tracking, Tensorboard uploader ([e94c9db](https://www.github.com/googleapis/python-aiplatform/commit/e94c9dbeac701390b25e6d70b0b0acc270636029)) + + +### Bug Fixes + +* **deps:** add packaging requirement ([#392](https://www.github.com/googleapis/python-aiplatform/issues/392)) ([47c1530](https://www.github.com/googleapis/python-aiplatform/commit/47c15300d6c8e879e1d7a10ad0c79e2bb4f18aee)) +* enable aiplatform unit tests ([dcc459d](https://www.github.com/googleapis/python-aiplatform/commit/dcc459d55890961a8aa3cadb696c023a991eea05)) +* rollback six to 1.15 ([#391](https://www.github.com/googleapis/python-aiplatform/issues/391)) ([066624b](https://www.github.com/googleapis/python-aiplatform/commit/066624b7c2ab3af281b7f63e47c990efbcd52673)) + ## [0.8.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.7.1...v0.8.0) (2021-05-11) diff --git a/setup.py b/setup.py index 4a91bfb0ef..a40d87c1da 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ import setuptools # type: ignore name = "google-cloud-aiplatform" -version = "0.8.0" +version = "0.9.0" description = "Cloud AI Platform API client library" package_root = os.path.abspath(os.path.dirname(__file__))